Posted to commits@madlib.apache.org by ri...@apache.org on 2018/06/26 08:58:09 UTC

madlib git commit: DT: Add impurity variable importance

Repository: madlib
Updated Branches:
  refs/heads/master 69ab239b8 -> b8031a03f


DT: Add impurity variable importance

JIRA: MADLIB-1205

Breiman et al. [1] describe a "Gini importance" measure that can be
computed for a single decision tree. The measure is the impurity
decrease produced by a feature at each node where it is the primary
split, accumulated over the whole tree. Surrogate splits can also
contribute to this measure by scaling the impurity decrease with the
adjusted surrogate agreement.

This commit adds the importance measure for all supported impurity
functions (hence the term "impurity importance" rather than "Gini
importance").

[1] https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#giniimp
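
A minimal sketch of this accumulation over a single tree (the ToyTree
container and function names below are illustrative, not the MADlib API):

    from dataclasses import dataclass
    from typing import List, Tuple

    @dataclass
    class ToyTree:
        feature: List[int]      # primary split feature per node, -1 for leaves
        gain: List[float]       # impurity decrease of the primary split per node
        surrogates: List[List[Tuple[int, float]]]  # (feature, adjusted agreement)

    def impurity_importance(tree: ToyTree, n_features: int) -> List[float]:
        importance = [0.0] * n_features
        for node, feat in enumerate(tree.feature):
            if feat < 0:                  # leaf: no split, no contribution
                continue
            gain = tree.gain[node]
            importance[feat] += gain      # primary split contribution
            for surr_feat, adj_agreement in tree.surrogates[node]:
                importance[surr_feat] += gain * adj_agreement   # surrogate part
        total = sum(importance) + 1e-6    # small epsilon avoids divide-by-zero
        return [100.0 * v / total for v in importance]   # scaled to percentages

    # A one-split stump on feature 0 with a surrogate on feature 1:
    # impurity_importance(ToyTree([0, -1, -1], [0.25, 0, 0],
    #                             [[(1, 0.5)], [], []]), 2)   # ~[66.7, 33.3]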

Closes #277

Co-authored-by: Nandish Jayaram <nj...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/b8031a03
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/b8031a03
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/b8031a03

Branch: refs/heads/master
Commit: b8031a03f0bdb4481b3470831a07fc167b5d46ec
Parents: 69ab239
Author: Rahul Iyer <ri...@apache.org>
Authored: Tue May 29 18:09:18 2018 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Tue Jun 26 01:56:21 2018 -0700

----------------------------------------------------------------------
 src/modules/recursive_partitioning/DT_impl.hpp  |  91 +++++++-
 src/modules/recursive_partitioning/DT_proto.hpp |  23 ++
 .../recursive_partitioning/decision_tree.cpp    |  35 ++-
 .../recursive_partitioning/decision_tree.hpp    |   3 +-
 .../recursive_partitioning/decision_tree.py_in  | 209 +++++++++++-------
 .../recursive_partitioning/decision_tree.sql_in | 220 ++++++++++++-------
 .../test/decision_tree.sql_in                   |  33 ++-
 7 files changed, 428 insertions(+), 186 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/modules/recursive_partitioning/DT_impl.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/DT_impl.hpp b/src/modules/recursive_partitioning/DT_impl.hpp
index 69d0817..d9477c9 100644
--- a/src/modules/recursive_partitioning/DT_impl.hpp
+++ b/src/modules/recursive_partitioning/DT_impl.hpp
@@ -116,13 +116,11 @@ template <class Container>
 inline
 void
 DecisionTree<Container>::bind(ByteStream_type& inStream) {
-
     inStream >> tree_depth
              >> n_y_labels
              >> max_n_surr
              >> is_regression
              >> impurity_type;
-
     size_t n_nodes = 0;
     size_t n_labels = 0;
     size_t max_surrogates = 0;
@@ -406,8 +404,8 @@ DecisionTree<Container>::impurityGain(const ColumnVector &combined_stats,
     double false_count = statWeightedCount(combined_stats.segment(stats_per_split, stats_per_split));
     double total_count = true_count + false_count;
 
-    if (true_count == 0 || false_count == 0) {
-        // no gain if all fall into one side
+    if (total_count == 0 || true_count == 0 || false_count == 0) {
+        // no gain if there are no incoming tuples or if all fall into one side
         return 0.;
     }
     double true_weight = true_count / total_count;
@@ -574,7 +572,6 @@ DecisionTree<Container>::expand(const Accumulator &state,
                         max_stats.segment(0, sps),   // true_stats
                         max_stats.segment(sps, sps)  // false_stats
                     );
-
             } else {
                 feature_indices(current) = FINISHED_LEAF;
             }
@@ -619,11 +616,11 @@ DecisionTree<Container>::pickSurrogates(
     // we use the *_agg_matrix to add every two alternate columns of the
     // *_stats matrix to create a forward and reverse agreement metric
     // for each split.
-    // eg. in cat_stats,
-    //      we add columns 1 and 3 to get the <= split agreement for 1st cat split.
-    //      we add columns 2 and 4 to get the > split agreement for 1st cat split.
-    //      we add columns 5 and 7 to get the <= split agreement for 2nd cat split.
-    //      we add columns 6 and 8 to get the > split agreement for 2nd cat split.
+    // e.g. for cat_stats,
+    //      add columns 1 and 3 to get the <= split agreement for 1st cat split.
+    //      add columns 2 and 4 to get the > split agreement for 1st cat split.
+    //      add columns 5 and 7 to get the <= split agreement for 2nd cat split.
+    //      add columns 6 and 8 to get the > split agreement for 2nd cat split.
     ColumnVector fwd_agg_vec(4);
     fwd_agg_vec << 1, 0, 1, 0;
     ColumnVector rev_agg_vec(4);
@@ -1460,6 +1457,80 @@ DecisionTree<Container>::encodeIndex(const int &feature_index,
 
 template <class Container>
 inline
+void
+DecisionTree<Container>::computeVariableImportance(
+        ColumnVector &cat_var_importance,
+        ColumnVector &con_var_importance){
+
+    // stats_per_split
+    uint16_t sps = is_regression ? REGRESS_N_STATS :
+                                   static_cast<uint16_t>(n_y_labels + 1);
+
+    // loop through each internal node and check the primary split and any
+    // surrogate splits
+    for (Index node_index = 0;
+            node_index < feature_indices.size() / 2;
+            node_index++){
+        if (isInternalNode(node_index)){
+            ColumnVector combined_stats(sps * 2);
+            combined_stats << predictions.row(trueChild(node_index)).transpose(),
+                              predictions.row(falseChild(node_index)).transpose();
+            double split_gain = impurityGain(combined_stats, sps);
+
+            // importance = impurity gain from split +
+            //                  impurity gain * adjusted agreement from
+            //                      surrogate split
+
+            // primary split contribution to importance
+            Index feat_index = feature_indices(node_index);
+            if(is_categorical(node_index)) {
+                assert(feat_index < cat_var_importance.size());
+                cat_var_importance(feat_index) += split_gain;
+            } else {
+                assert(feat_index < con_var_importance.size());
+                con_var_importance(feat_index) += split_gain;
+            }
+
+            // surrogate contribution to importance
+            if (max_n_surr > 0){
+                for (Index surr_count=0; surr_count < max_n_surr; surr_count++){
+                    Index surr_lookup_index = node_index * max_n_surr + surr_count;
+
+                    // surr_status == 0 implies non-existing surrogate
+                    if (surr_status(surr_lookup_index) == 0)
+                        break;
+
+                    uint64_t total_count =
+                        statCount(predictions.row(trueChild(node_index))) +
+                        statCount(predictions.row(falseChild(node_index)));
+
+                    // Adjusted agreement measures how much better the surrogate
+                    // does compared to the majority branch. This value is
+                    // relative to the number of rows that the majority branch
+                    // would not predict correctly (the minority count).
+                    uint64_t node_count = nodeCount(node_index);
+                    uint64_t maj_count = getMajorityCount(node_index);
+                    uint64_t min_count =  node_count - maj_count;
+
+                    double adj_agreement =
+                        (surr_agreement(surr_lookup_index) - maj_count) / min_count;
+                    Index surr_feat_index = surr_indices(surr_lookup_index);
+                    if (std::abs(surr_status(surr_lookup_index)) == 1){
+                        cat_var_importance[surr_feat_index] += split_gain *
+                                                                adj_agreement;
+                    } else {
+                        con_var_importance[surr_feat_index] += split_gain *
+                                                                adj_agreement;
+                    }
+                }
+            }
+        }
+    }
+}
+// -------------------------------------------------------------------------
+
+template <class Container>
+inline
 string
 DecisionTree<Container>::surr_display(
         ArrayHandle<text*> &cat_features_str,

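As a worked example of the adjusted agreement used in computeVariableImportance
above (the numbers and helper name below are illustrative, not from the patch):
take a node with 100 rows whose majority branch receives 70 of them, and a
surrogate that agrees with the primary split on 85 rows.

    def adjusted_agreement(surr_agreement, node_count, majority_count):
        # how much better the surrogate does than simply following the
        # majority branch, relative to the rows the majority rule gets wrong
        minority_count = node_count - majority_count
        return (surr_agreement - majority_count) / minority_count

    adjusted_agreement(85, 100, 70)   # (85 - 70) / 30 = 0.5

The surrogate recovers half of the rows the majority rule would misroute, so it
contributes half of the node's impurity gain to its feature's importance.
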
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/modules/recursive_partitioning/DT_proto.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/DT_proto.hpp b/src/modules/recursive_partitioning/DT_proto.hpp
index 1852f69..6980326 100644
--- a/src/modules/recursive_partitioning/DT_proto.hpp
+++ b/src/modules/recursive_partitioning/DT_proto.hpp
@@ -29,6 +29,17 @@ using namespace dbal::eigen_integration;
 using std::vector;
 using std::string;
 
+
+// stats_per_split is the number of statistics accumulated for each split.
+// This differs between classification and regression: the value for
+// classification is a function of the number of response values and is
+// computed at runtime, whereas the value for regression is a constant.
+
+// For classification, accumulate the following:
+//  num of weighted tuples for each possible response and num of unweighted tuples
+
+// For regression, accumulate 4 values for evaluating a split:
+//      weight, weight * response, weight * response^2, num of unweighted rows
 const uint16_t REGRESS_N_STATS = 4u;
 
 template <class Container>
@@ -93,6 +104,13 @@ public:
     }
     Index trueChild(Index current) const { return 2 * current + 1; }
     Index falseChild(Index current) const { return 2 * current + 2; }
+    bool isInternalNode(const Index node_index) const {
+       int split_feat_index = feature_indices(node_index);
+       return (split_feat_index != NODE_NON_EXISTING &&
+                split_feat_index != IN_PROCESS_LEAF &&
+                split_feat_index != FINISHED_LEAF);
+    }
+
     double impurity(const ColumnVector & stats) const;
     double impurityGain(const ColumnVector &combined_stats,
                         const uint16_t &stats_per_split) const;
@@ -149,6 +167,9 @@ public:
                       const int &is_categorical,
                       const int &n_cat_features) const;
 
+    void computeVariableImportance(ColumnVector& cat_var_importance,
+                                   ColumnVector& con_var_importance);
+
     // attributes
     // dimension information
     uint16_type tree_depth; // 1 for root-only tree
@@ -172,6 +193,7 @@ public:
     IntegerVector_type feature_indices;
     // elements are of integer type for categorical
     ColumnVector_type feature_thresholds;
+
     // used as boolean array, 0 means continuous, otherwise categorical
     IntegerVector_type is_categorical;
 
@@ -207,6 +229,7 @@ public:
     Matrix_type predictions;   // used as integer if we do classification
 };
 
+
 // ------------------------------------------------------------------------
 // TreeAccumulator is used for collecting statistics during training the nodes
 // for a level of the decision tree. The same accumulator is also used for

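The stats_per_split layout documented above, restated as a small sketch (the
helper names are hypothetical; the real accumulation happens inside the C++
aggregates):

    def regression_split_stats(weights, responses):
        # REGRESS_N_STATS = 4: weight, weight * response, weight * response^2,
        # and the unweighted row count
        return [sum(weights),
                sum(w * y for w, y in zip(weights, responses)),
                sum(w * y * y for w, y in zip(weights, responses)),
                len(responses)]

    def classification_split_stats(weights, labels, n_labels):
        # one weighted count per response value plus the unweighted row count,
        # i.e. n_labels + 1 statistics per split
        stats = [0.0] * (n_labels + 1)
        for w, y in zip(weights, labels):
            stats[y] += w
        stats[-1] = len(labels)
        return stats
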
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/modules/recursive_partitioning/decision_tree.cpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/decision_tree.cpp b/src/modules/recursive_partitioning/decision_tree.cpp
index 5579cb5..42ea942 100644
--- a/src/modules/recursive_partitioning/decision_tree.cpp
+++ b/src/modules/recursive_partitioning/decision_tree.cpp
@@ -54,7 +54,6 @@ initialize_decision_tree::run(AnyType & args){
     std::string impurity_func_str = args[1].getAs<std::string>();
     uint16_t n_y_labels = args[2].getAs<uint16_t>();
     uint16_t max_n_surr = args[3].getAs<uint16_t>();
-
     if (is_regression_tree)
         n_y_labels = REGRESS_N_STATS;
     dt.rebind(1u, n_y_labels, max_n_surr, is_regression_tree);
@@ -174,10 +173,7 @@ compute_leaf_stats_transition::run(AnyType & args){
             }
         }
 
-        // For classification, we store for each split the number of weighted
-        // tuples for each possible response value and the number of unweighted
-        // tuples landing on that node.
-        // For regression, REGRESS_N_STATS determines the number of stats per split
+        // see DT_proto.hpp for explanation on stats_per_split
         uint16_t stats_per_split = dt.is_regression ?
             REGRESS_N_STATS : static_cast<uint16_t>(n_response_labels + 1);
         const bool weights_as_rows = args[9].getAs<bool>();
@@ -491,6 +487,30 @@ print_decision_tree::run(AnyType &args){
 }
 
 AnyType
+get_variable_importance::run(AnyType &args){
+    Tree dt = args[0].getAs<ByteString>();
+    const int n_cat_features = args[1].getAs<int>();
+    const int n_con_features = args[2].getAs<int>();
+
+    ColumnVector cat_var_importance = ColumnVector::Zero(n_cat_features);
+    ColumnVector con_var_importance = ColumnVector::Zero(n_con_features);
+    dt.computeVariableImportance(cat_var_importance, con_var_importance);
+
+    // Variable importance is scaled to represent a percentage. Even though
+    // the importance values are split between categorical and continuous, the
+    // percentages are relative to the combined set.
+    ColumnVector combined_var_imp(n_cat_features + n_con_features);
+    combined_var_imp << cat_var_importance, con_var_importance;
+
+    // Avoid divide by zero by adding a small number.
+    double total_var_imp = combined_var_imp.sum();
+    double VAR_IMP_EPSILON = 1e-6;
+    combined_var_imp *=  (100.0 / (total_var_imp + VAR_IMP_EPSILON));
+
+    return combined_var_imp;
+}
+
+AnyType
 display_text_tree::run(AnyType &args){
     Tree dt = args[0].getAs<ByteString>();
     ArrayHandle<text*> cat_feature_names = args[1].getAs<ArrayHandle<text*> >();
@@ -511,8 +531,8 @@ display_text_tree::run(AnyType &args){
 void mark_subtree_removal_recur(MutableTree &dt, int me) {
     if (me < dt.predictions.rows() &&
             dt.feature_indices(me) != dt.NODE_NON_EXISTING) {
-        int left = static_cast<int>(dt.trueChild(static_cast<Index>(me))),
-            right = static_cast<int>(dt.falseChild(static_cast<Index>(me)));
+        int left = static_cast<int>(dt.trueChild(static_cast<Index>(me)));
+        int right = static_cast<int>(dt.falseChild(static_cast<Index>(me)));
         mark_subtree_removal_recur(dt, left);
         mark_subtree_removal_recur(dt, right);
         dt.feature_indices(me) = dt.NODE_NON_EXISTING;
@@ -724,6 +744,7 @@ prune_and_cplist::run(AnyType &args){
     std::vector<double> node_complexities(dt.feature_indices.size(), alpha);
 
     prune_tree(dt, 0, alpha, root_risk, node_complexities);
+
     // Get the new tree_depth after pruning
     //  Note: externally, tree_depth starts from 0 but DecisionTree assumes
     //  tree_depth starts from 1

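The scaling step in get_variable_importance above, as a small sketch (the
function name is hypothetical): categorical importances come first, continuous
importances follow, and the percentages are taken over the combined vector.

    def scale_importance(cat_importance, con_importance, eps=1e-6):
        combined = list(cat_importance) + list(con_importance)
        total = sum(combined) + eps      # epsilon guards an all-zero vector
        return [100.0 * v / total for v in combined]

    # scale_importance([0.25], [0.0, 0.125]) -> [~66.7, 0.0, ~33.3]
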
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/modules/recursive_partitioning/decision_tree.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/decision_tree.hpp b/src/modules/recursive_partitioning/decision_tree.hpp
index 3a3965f..ae62bfa 100644
--- a/src/modules/recursive_partitioning/decision_tree.hpp
+++ b/src/modules/recursive_partitioning/decision_tree.hpp
@@ -14,6 +14,7 @@ DECLARE_UDF(recursive_partitioning, compute_surr_stats_transition)
 DECLARE_UDF(recursive_partitioning, dt_surr_apply)
 
 DECLARE_UDF(recursive_partitioning, print_decision_tree)
+DECLARE_UDF(recursive_partitioning, get_variable_importance)
 DECLARE_UDF(recursive_partitioning, predict_dt_response)
 DECLARE_UDF(recursive_partitioning, predict_dt_prob)
 
@@ -24,5 +25,5 @@ DECLARE_UDF(recursive_partitioning, display_text_tree)
 DECLARE_UDF(recursive_partitioning, convert_to_rpart_format)
 DECLARE_UDF(recursive_partitioning, get_split_thresholds)
 DECLARE_UDF(recursive_partitioning, prune_and_cplist)
-    
+
 DECLARE_UDF(recursive_partitioning, convert_to_random_forest_format)

http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
index 04fde7e..57c3025 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
@@ -28,6 +28,7 @@ from utilities.validate_args import columns_exist_in_table
 from utilities.validate_args import is_var_valid
 from utilities.validate_args import unquote_ident
 from utilities.utilities import _assert
+from utilities.utilities import _array_to_string
 from utilities.utilities import extract_keyvalue_params
 from utilities.utilities import unique_string
 from utilities.utilities import add_postfix
@@ -187,7 +188,6 @@ def _classify_features(feature_to_type, features):
                             if feature_to_type[c] in boolean_types]
 
     # Integer types are not considered continuous
-    con_types = ['real', 'double precision', 'numeric', 'decimal']
     con_features = [c for c in features
                     if is_psql_numeric_type(feature_to_type[c], exclude=int_types)]
 
@@ -331,6 +331,11 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion,
             if 'cp_list' in pruned_tree:
                 tree['cp_list'] = pruned_tree['cp_list']
 
+        importance_vectors = _compute_var_importance(
+            schema_madlib, tree,
+            len(cat_features), len(con_features))
+        tree.update(**importance_vectors)
+
     return tree_states, bins, dep_list, n_rows
 # -------------------------------------------------------------------------
 
@@ -394,7 +399,8 @@ def _build_tree(schema_madlib, is_classification, split_criterion,
         # we need to let it know that right now it is called from CV.
         use_existing_tables = table_exists(output_table_name)  # create tables if it does not exist
         running_cv = True  # flag to indicate that cv fold ID needs to be included
-        _create_output_tables(**locals())
+        _create_output_tables(list_of_features_to_exclude='',
+                              **locals())
 # ------------------------------------------------------------------------------
 
 
@@ -483,6 +489,13 @@ def tree_train(schema_madlib, training_table_name, output_table_name,
         cat_features, ordered_cat_features, boolean_cats, con_features = \
             _classify_features(all_cols_types, features)
 
+        # assert that the continuous and categorical features together
+        # cover all features
+        invalid_features = set(features) - (set(cat_features) | set(con_features))
+        _assert(not invalid_features,
+                "DT error: Some of the features are invalid ({0})".
+                format(invalid_features))
+
         # get all rows
         n_all_rows = plpy.execute("SELECT count(*) FROM {source_table}".
                                   format(source_table=training_table_name)
@@ -513,6 +526,7 @@ def tree_train(schema_madlib, training_table_name, output_table_name,
 def _create_output_tables(schema_madlib, training_table_name, output_table_name,
                           tree_states, bins, split_criterion,
                           id_col_name, dependent_variable, list_of_features,
+                          list_of_features_to_exclude,
                           is_classification, n_all_rows, n_rows, dep_list, cp,
                           all_cols_types, grouping_cols=None,
                           use_existing_tables=False, running_cv=False,
@@ -532,7 +546,7 @@ def _create_output_tables(schema_madlib, training_table_name, output_table_name,
     _create_summary_table(
         schema_madlib, split_criterion, training_table_name,
         output_table_name, id_col_name, bins['cat_features'], bins['con_features'],
-        dependent_variable, list_of_features,
+        dependent_variable, list_of_features, list_of_features_to_exclude,
         failed_groups, is_classification, n_all_rows,
         n_rows, dep_list, all_cols_types, cp, grouping_cols, 1,
         use_existing_tables, n_folds, null_proxy)
@@ -783,35 +797,45 @@ def _create_result_table(schema_madlib, tree_state,
         header = "create table " + output_table_name + " as "
     depth = (tree_state['pruned_depth'] if 'pruned_depth' in tree_state
              else tree_state['tree_depth'])
+
     if len(cat_features) > 0:
         sql = header + """
                 SELECT
-                    {cp} as pruning_cp,
-                    $1 as tree,
-                    $2 as cat_levels_in_text,
-                    $3 as cat_n_levels,
-                    {depth} as tree_depth
+                    {cp} AS pruning_cp,
+                    $1 AS tree,
+                    $2 AS cat_levels_in_text,
+                    $3 AS cat_n_levels,
+                    $4 AS impurity_var_importance,
+                    {depth} AS tree_depth
                     {fold}
             """.format(depth=depth,
                        cp=tree_state['cp'],
                        fold=fold)
         sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib),
-                                      'text[]', 'integer[]'])
-        plpy.execute(sql_plan, [tree_state['tree_state'], cat_origin, cat_n])
+                                      'text[]',
+                                      'integer[]',
+                                      'double precision[]'])
+        plpy.execute(sql_plan, [tree_state['tree_state'],
+                                cat_origin,
+                                cat_n,
+                                tree_state['impurity_var_importance']])
     else:
         sql = header + """
                 SELECT
-                    {cp} as pruning_cp,
-                    $1 as tree,
-                    NULL::text[] as cat_levels_in_text,
-                    NULL::integer[] as cat_n_levels,
-                    {depth} as tree_depth
+                    {cp} AS pruning_cp,
+                    $1 AS tree,
+                    NULL::text[] AS cat_levels_in_text,
+                    NULL::integer[] AS cat_n_levels,
+                    $2 AS impurity_var_importance,
+                    {depth} AS tree_depth
                     {fold}
             """.format(depth=depth,
                        cp=tree_state['cp'],
                        fold=fold)
-        sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib)])
-        plpy.execute(sql_plan, [tree_state['tree_state']])
+        sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib),
+                                      'double precision[]'])
+        plpy.execute(sql_plan, [tree_state['tree_state'],
+                                tree_state['impurity_var_importance']])
 
 # ------------------------------------------------------------
 
@@ -1007,8 +1031,8 @@ def _get_bins_grps(
         #   cat_origin = [0, 1, 4, 6, 8, 0, 1, 4, 6, 0, 1, 4, 6, 8]
         #   grp_key_cat = ['3', '4', '5']
         cat_items_list = [rows[col]
-                            for grp_key, rows in grp_to_col_to_levels
-                            for col in cat_features if col in rows]
+                          for grp_key, rows in grp_to_col_to_levels
+                          for col in cat_features if col in rows]
         cat_n = [len(i) for i in cat_items_list]
         cat_origin = [item for sublist in cat_items_list for item in sublist]
         grp_key_cat = [item[0] for item in grp_to_col_to_levels]
@@ -1097,28 +1121,21 @@ def _one_step(schema_madlib, training_table_name, cat_features,
                                                          "$3", "$2",
                                                          null_proxy)
 
-    # The arguments of the aggregate (in the same order):
-    # 1. current tree state, madlib.bytea8
-    # 2. categorical features (integer format) in a single array
-    # 3. continuous features in a single array
-    # 4. weight value
-    # 5. categorical sorted levels (integer format) in a combined array
-    # 6. continuous splits
-    # 7. number of dependent levels
     train_sql = """
         SELECT (result).* from (
             SELECT
-                {schema_madlib}._dt_apply($1,
+                {schema_madlib}._dt_apply(
+                    $1,
                     {schema_madlib}._compute_leaf_stats(
-                        $1,
-                        {cat_features_str},
-                        {con_features_str},
+                        $1,                  -- current tree state, madlib.bytea8
+                        {cat_features_str},  -- categorical features in an array
+                        {con_features_str},  -- continuous features in an array
                         {dep_var},
-                        {weights},
-                        $2,
-                        $4,
-                        {dep_n_levels}::smallint,
-                        {subsample}::boolean
+                        {weights},           -- weight value
+                        $2,                  -- categorical sorted levels in a combined array
+                        $4,                  -- continuous splits
+                        {dep_n_levels}::smallint, -- number of dependent levels
+                        {subsample}::boolean  -- should we use a subsample of data
                     ),
                     $4,
                     {min_split}::smallint,
@@ -1364,12 +1381,14 @@ def _create_grp_result_table(
         running_cv=False, k=0):
     """ Create the output table for grouping case.
     """
+
     grp_key = unique_string()
     tree_state = unique_string()
     tree_depth = unique_string()
     cp_col = unique_string()
     cat_levels_in_text = unique_string()
     cat_n_levels = unique_string()
+    impurity_var_importance = unique_string()
     cat_levels_val = cat_levels_in_text if cat_features else "NULL::TEXT[]"
     cat_n_levels_val = cat_n_levels if cat_features else "NULL::INTEGER[]"
     grouping_array_str = bins['grouping_array_str']
@@ -1383,25 +1402,28 @@ def _create_grp_result_table(
     sql = header + """
             SELECT
                 {grouping_cols},
-                {tree_state} as tree,
-                {cat_levels_val} as cat_levels_in_text,
-                {cat_n_levels_val} as cat_n_levels,
-                {tree_depth} as tree_depth,
-                {cp_col} as pruning_cp
+                {tree_state} AS tree,
+                {cat_levels_val} AS cat_levels_in_text,
+                {cat_n_levels_val} AS cat_n_levels,
+                string_to_array(trim(both '{{}}' FROM {impurity_var_importance}),
+                                ',')::double precision[] AS impurity_var_importance,
+                {tree_depth} AS tree_depth,
+                {cp_col} AS pruning_cp
                 {fold}
             FROM (
                 SELECT
                     {grouping_cols},
-                    {grouping_array_str} as {grp_key}
+                    {grouping_array_str} AS {grp_key}
                 FROM {training_table_name}
                 group by {grouping_cols}
             ) s1
             JOIN (
                 SELECT
-                    unnest($1) as {grp_key},
-                    unnest($2) as {tree_state},
-                    unnest($3) as {tree_depth},
-                    unnest($4) as {cp_col}
+                    unnest($1) AS {grp_key},
+                    unnest($2) AS {tree_state},
+                    unnest($3) AS {tree_depth},
+                    unnest($4) AS {cp_col},
+                    unnest($5) AS {impurity_var_importance}
             ) s2
             USING ({grp_key})
             """
@@ -1413,41 +1435,30 @@ def _create_grp_result_table(
                         cat_n_levels as {cat_n_levels},
                         cat_levels_in_text as {cat_levels_in_text}
                     FROM
-                        {schema_madlib}._gen_cat_levels_set($5, $6, $7, $8)
-
+                        {schema_madlib}._gen_cat_levels_set($6, $7, $8, $9)
                 ) s3
                 USING ({grp_key})
             """
     sql = sql.format(**locals())
+    prepare_list = ['text[]',
+                    '{schema_madlib}.bytea8[]'.format(schema_madlib=schema_madlib),
+                    'integer[]', 'double precision[]', 'text[]']
+    execute_list = [
+        [t['grp_key'] for t in tree_states],
+        [t['tree_state'] for t in tree_states],
+        [t['pruned_depth'] if 'pruned_depth' in t else t['tree_depth']
+         for t in tree_states],
+        [t['cp'] for t in tree_states],
+        [_array_to_string(t['impurity_var_importance']) for t in tree_states]]
     if cat_features:
-        sql_plan = plpy.prepare(
-            sql, ['text[]',
-                  '{schema_madlib}.bytea8[]'.format(schema_madlib=schema_madlib),
-                  'integer[]', 'double precision[]',
-                  'text[]', 'integer[]', 'integer', 'text[]'])
-        plpy.execute(sql_plan, [
-            [t['grp_key'] for t in tree_states],
-            [t['tree_state'] for t in tree_states],
-            [t['pruned_depth'] if 'pruned_depth' in t else t['tree_depth']
-             for t in tree_states],
-            [t['cp'] for t in tree_states],
+        prepare_list += ['text[]', 'integer[]', 'integer', 'text[]']
+        execute_list += [
             bins['grp_key_cat'],
             bins['cat_n'],
             len(cat_features),
-            bins['cat_origin']])
-    else:
-        sql_plan = plpy.prepare(sql, [
-            'text[]',
-            '{schema_madlib}.bytea8[]'.format(schema_madlib=schema_madlib),
-            'integer[]',
-            'double precision[]'])
-        plpy.execute(sql_plan, [
-            [t['grp_key'] for t in tree_states],
-            [t['tree_state'] for t in tree_states],
-            [t['pruned_depth'] if 'pruned_depth' in t else t['tree_depth']
-             for t in tree_states],
-            [t['cp'] for t in tree_states]
-        ])
+            bins['cat_origin']]
+    sql_plan = plpy.prepare(sql, prepare_list)
+    plpy.execute(sql_plan, execute_list)
 # ------------------------------------------------------------
 
 
@@ -1476,6 +1487,7 @@ def _create_summary_table(
         schema_madlib, split_criterion,
         training_table_name, output_table_name, id_col_name,
         cat_features, con_features, dependent_variable, list_of_features,
+        list_of_features_to_exclude,
         num_failed_groups, is_classification, n_all_rows, n_rows,
         dep_list, all_cols_types, cp, grouping_cols=None, n_groups=1,
         use_existing_tables=False, n_folds=0, null_proxy=None):
@@ -1491,6 +1503,7 @@ def _create_summary_table(
         dep_list_str = "NULL"
     indep_type = ', '.join(all_cols_types[c] for c in cat_features + con_features)
     dep_type = _get_dep_type(training_table_name, dependent_variable)
+    independent_varnames = ','.join(cat_features + con_features)
     cat_features_str = ','.join(cat_features)
     con_features_str = ','.join(con_features)
     if grouping_cols:
@@ -1517,8 +1530,10 @@ def _create_summary_table(
                 '{training_table_name}'::text  AS source_table,
                 '{output_table_name}'::text    AS model_table,
                 '{id_col_name}'::text          AS id_col_name,
+                '{list_of_features}'::text     AS list_of_features,
+                '{list_of_features_to_exclude}'::text     AS list_of_features_to_exclude,
                 '{dependent_variable}'::text   AS dependent_varname,
-                '{list_of_features}'::text     AS independent_varnames,
+                '{independent_varnames}'::text     AS independent_varnames,
                 '{cat_features_str}'::text     AS cat_features,
                 '{con_features_str}'::text     AS con_features,
                 {grouping_cols_str}::text      AS grouping_cols,
@@ -1899,6 +1914,31 @@ def _prune_and_cplist(schema_madlib, tree, cp, compute_cp_list=False):
 # -------------------------------------------------------------------------
 
 
+def _compute_var_importance(schema_madlib, tree,
+                            n_cat_features, n_con_features):
+    """ Compute variable importance for categorical and continuous features
+
+        Args:
+            @param schema_madlib: str, MADlib schema name
+            @param tree: Tree data to prune
+            @param tree: dict, Tree state data for a single trained tree
+            @param n_con_features: int, Number of continuous features
+
+        Returns:
+            Dictionary containing the following key:
+                impurity_var_importance: Array of importance values
+    """
+    var_imp_sql = """
+        SELECT {schema_madlib}._get_var_importance(
+            $1,    -- trained decision tree
+            {n_cat_features},
+            {n_con_features}) AS impurity_var_importance
+        """.format(**locals())
+    var_imp_plan = plpy.prepare(var_imp_sql, [schema_madlib + '.bytea8'])
+    return plpy.execute(var_imp_plan, [tree['tree_state']])[0]
+# ------------------------------------------------------------------------------
+
+
 def _xvalidate(schema_madlib, tree_states, training_table_name, output_table_name,
                id_col_name, dependent_variable,
                list_of_features, list_of_features_to_exclude,
@@ -2056,6 +2096,10 @@ def _xvalidate(schema_madlib, tree_states, training_table_name, output_table_nam
                 tree['pruned_depth'] = pruned_tree['tree_depth']
             else:
                 tree['pruned_depth'] = 0
+        importance_vectors = _compute_var_importance(
+            schema_madlib, tree,
+            len(cat_features), len(con_features))
+        tree.update(**importance_vectors)
 
     plpy.execute("DROP TABLE {group_to_param_list_table}".format(**locals()))
 # ------------------------------------------------------------
@@ -2157,13 +2201,14 @@ def _tree_train_using_bins(
     tree_state = plpy.execute(
         """
         SELECT {schema_madlib}._initialize_decision_tree(
-            {is_regression_tree},
-            '{split_criterion}'::text,
-            {dep_n_levels}::smallint,
-            {max_n_surr}::smallint)
-            as tree_state, False as finished
+                {is_regression_tree},
+                '{split_criterion}'::text,
+                {dep_n_levels}::smallint,
+                {max_n_surr}::smallint
+                ) AS tree_state,
+            FALSE as finished
         """.format(schema_madlib=schema_madlib,
-                   is_regression_tree=not is_classification,
+                   is_regression_tree=(not is_classification),
                    split_criterion=split_criterion,
                    dep_n_levels=dep_n_levels,
                    max_n_surr=max_n_surr))[0]
@@ -2203,9 +2248,7 @@ def _tree_train_grps_using_bins(
                 {is_regression_tree},
                 '{split_criterion}'::text,
                 {dep_n_levels}::smallint,
-                {max_n_surr}::smallint
-                )
-            AS tree_state,
+                {max_n_surr}::smallint) AS tree_state,
             0 AS finished
         """.format(schema_madlib=schema_madlib,
                    is_regression_tree=not is_classification,

http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
index 8e69d9b..9d5a5ef 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
@@ -279,6 +279,18 @@ tree_train(
         for <em>weather_outlook</em> and 2 levels
         <em>windy</em>.</td>
       </tr>
+      <tr>
+      <th>impurity_var_importance</th>
+      <td>DOUBLE PRECISION[]. Impurity importance (also referred to as Gini
+      importance) of each variable. The order of the variables is the same as
+      that of the 'independent_varnames' column in the summary table (see below).
+
+      The impurity importance of a feature is the decrease in impurity at each
+      node that uses the feature as its primary split, summed over the whole
+      tree. If surrogates are used, the importance value also includes the
+      impurity decrease scaled by the adjusted surrogate agreement.
+      </td>
+      </tr>
 
       <tr>
       <th>tree_depth</th>
@@ -325,13 +337,24 @@ tree_train(
     </tr>
 
     <tr>
+    <th>list_of_features</th>
+    <td>TEXT. The list_of_features passed to the 'tree_train' procedure.</td>
+    </tr>
+
+    <tr>
+    <th>list_of_features_to_exclude</th>
+    <td>TEXT. The list_of_features_to_exclude passed to the 'tree_train' procedure.</td>
+    </tr>
+
+    <tr>
     <th>dependent_varname</th>
     <td>TEXT. The dependent variable.</td>
     </tr>
 
     <tr>
     <th>independent_varnames</th>
-    <td>TEXT. The independent variables.</td>
+    <td>TEXT. The independent variables. These are the features actually used
+    to train the decision tree.</td>
     </tr>
 
     <tr>
@@ -630,12 +653,17 @@ SELECT madlib.tree_train('dt_golf',         -- source table
 </pre>
 View the output table (excluding the tree which is in binary format):
 <pre class="example">
-SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
+\\x on
+SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
 </pre>
 <pre class="result">
- pruning_cp |        cat_levels_in_text        | cat_n_levels | tree_depth
-------------+----------------------------------+--------------+------------
-          0 | {overcast,rain,sunny,False,True} | {3,2}        |          5
+-[ RECORD 1 ]-----------+--------------------------------------
+pruning_cp              | 0
+cat_levels_in_text      | {overcast,rain,sunny,False,True}
+cat_n_levels            | {3,2}
+impurity_var_importance | {10.6171201061712,0,89.3828798938288}
+tree_depth              | 5
+
 </pre>
 View the summary table:
 <pre class="example">
@@ -643,27 +671,29 @@ View the summary table:
 SELECT * FROM train_output_summary;
 </pre>
 <pre class="result">
--[ RECORD 1 ]---------+--------------------------------
-method                | tree_train
-is_classification     | t
-source_table          | dt_golf
-model_table           | train_output
-id_col_name           | id
-dependent_varname     | class
-independent_varnames  | "OUTLOOK", temperature, windy
-cat_features          | "OUTLOOK",windy
-con_features          | temperature
-grouping_cols         |
-num_all_groups        | 1
-num_failed_groups     | 0
-total_rows_processed  | 14
-total_rows_skipped    | 0
-dependent_var_levels  | "Don't Play","Play"
-dependent_var_type    | text
-input_cp              | 0
-independent_var_types | text, boolean, double precision
-n_folds               | 0
-null_proxy            |
+-[ RECORD 1 ]---------------+--------------------------------
+method                      | tree_train
+is_classification           | t
+source_table                | dt_golf
+model_table                 | train_output
+id_col_name                 | id
+list_of_features            | "OUTLOOK", temperature, windy
+list_of_features_to_exclude | None
+dependent_varname           | class
+independent_varnames        | "OUTLOOK",windy,temperature
+cat_features                | "OUTLOOK",windy
+con_features                | temperature
+grouping_cols               |
+num_all_groups              | 1
+num_failed_groups           | 0
+total_rows_processed        | 14
+total_rows_skipped          | 0
+dependent_var_levels        | "Don't Play","Play"
+dependent_var_type          | text
+input_cp                    | 0
+independent_var_types       | text, boolean, double precision
+n_folds                     | 0
+null_proxy                  |
 </pre>
 
 -# Predict output categories.  For the purpose
@@ -890,16 +920,20 @@ SELECT madlib.tree_train('dt_golf',         -- source table
                          1,                 -- min bucket
                          10                 -- number of bins per continuous variable
                          );
-SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
+SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
 </pre>
 View the output table (excluding the tree which is in binary format):
 <pre class="example">
 SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
 </pre>
 <pre class="result">
- pruning_cp |               cat_levels_in_text               | cat_n_levels | tree_depth
-------------+------------------------------------------------+--------------+------------
-          0 | {medium,none,high,low,unhealthy,good,moderate} | {4,3}        |          3
+-[ RECORD 1 ]-----------+-----------------------------------------------------
+pruning_cp              | 0
+cat_levels_in_text      | {medium,none,high,low,unhealthy,good,moderate}
+cat_n_levels            | {4,3}
+impurity_var_importance | {0,40.2340084993653,5.6791213643137,54.086870136321}
+tree_depth              | 3
+
 </pre>
 The first 4 levels correspond to cloud ceiling and the next 3 levels
 correspond to air quality.
@@ -1032,12 +1066,15 @@ SELECT madlib.tree_train('mt_cars',         -- source table
 View the output table (excluding the tree which is in binary format)
 which shows ordering of levels of categorical variables 'vs' and 'cyl':
 <pre class="example">
-SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
+SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
 </pre>
 <pre class="result">
- pruning_cp | cat_levels_in_text | cat_n_levels | tree_depth
-------------+--------------------+--------------+------------
-          0 | {0,1,4,6,8}        | {2,3}        |          4
+pruning_cp              | 0
+cat_levels_in_text      | {0,1,4,6,8}
+cat_n_levels            | {2,3}
+impurity_var_importance | {0,51.8593201959496,10.976977929129,5.31897402755374,31.8447278473677}
+tree_depth              | 4
+
 </pre>
 View the summary table:
 <pre class="example">
@@ -1045,27 +1082,30 @@ View the summary table:
 SELECT * FROM train_output_summary;
 </pre>
 <pre class="result">
--[ RECORD 1 ]---------+-----------------------------------------------------------------------
-method                | tree_train
-is_classification     | f
-source_table          | mt_cars
-model_table           | train_output
-id_col_name           | id
-dependent_varname     | mpg
-independent_varnames  | *
-cat_features          | vs,cyl
-con_features          | disp,qsec,wt
-grouping_cols         |
-num_all_groups        | 1
-num_failed_groups     | 0
-total_rows_processed  | 32
-total_rows_skipped    | 0
-dependent_var_levels  |
-dependent_var_type    | double precision
-input_cp              | 0
-independent_var_types | integer, integer, double precision, double precision, double precision
-n_folds               | 0
-null_proxy            |
+-[ RECORD 1 ]---------------+-----------------------------------------------------------------------
+method                      | tree_train
+is_classification           | f
+source_table                | mt_cars
+model_table                 | train_output
+id_col_name                 | id
+list_of_features            | *
+list_of_features_to_exclude | id, hp, drat, am, gear, carb
+dependent_varname           | mpg
+independent_varnames        | vs,cyl,disp,qsec,wt
+cat_features                | vs,cyl
+con_features                | disp,qsec,wt
+grouping_cols               |
+num_all_groups              | 1
+num_failed_groups           | 0
+total_rows_processed        | 32
+total_rows_skipped          | 0
+dependent_var_levels        |
+dependent_var_type          | double precision
+input_cp                    | 0
+independent_var_types       | integer, integer, double precision, double precision, double precision
+n_folds                     | 0
+null_proxy                  |
+
 </pre>
 
 -# Predict regression output for the same data and compare with original:
@@ -1212,12 +1252,16 @@ View the output table (excluding the tree which is in binary format).
 The input cp value was 0 (default) and the best 'pruning_cp' value
 turns out to be 0 as well in this small example:
 <pre class="example">
-SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
+SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
 </pre>
 <pre class="result">
- pruning_cp | cat_levels_in_text | cat_n_levels | tree_depth
-------------+--------------------+--------------+------------
-          0 | {0,1,4,6,8}        | {2,3}        |          4
+-[ RECORD 1 ]-----------+-----------------------------------------------------------------------
+pruning_cp              | 0
+cat_levels_in_text      | {0,1,4,6,8}
+cat_n_levels            | {2,3}
+impurity_var_importance | {0,51.8593201959496,10.976977929129,5.31897402755374,31.8447278473677}
+tree_depth              | 4
+
 </pre>
 The cp values tested and average error and standard deviation are:
 <pre class="example">
@@ -1288,27 +1332,29 @@ View the summary table:
 SELECT * FROM train_output_summary;
 </pre>
 <pre class='result'>
--[ RECORD 1 ]---------+-----------------------
-method                | tree_train
-is_classification     | t
-source_table          | null_handling_example
-model_table           | train_output
-id_col_name           | id
-dependent_varname     | response
-independent_varnames  | country, weather, city
-cat_features          | country,weather,city
-con_features          |
-grouping_cols         |
-num_all_groups        | 1
-num_failed_groups     | 0
-total_rows_processed  | 4
-total_rows_skipped    | 0
-dependent_var_levels  | "a","b","c","d"
-dependent_var_type    | text
-input_cp              | 0
-independent_var_types | text, text, text
-n_folds               | 0
-null_proxy            | __NULL__
+-[ RECORD 1 ]---------------+-----------------------
+method                      | tree_train
+is_classification           | t
+source_table                | null_handling_example
+model_table                 | train_output
+id_col_name                 | id
+list_of_features            | country, weather, city
+list_of_features_to_exclude | None
+dependent_varname           | response
+independent_varnames        | country,weather,city
+cat_features                | country,weather,city
+con_features                |
+grouping_cols               | [NULL]
+num_all_groups              | 1
+num_failed_groups           | 0
+total_rows_processed        | 4
+total_rows_skipped          | 0
+dependent_var_levels        | "a","b","c","d"
+dependent_var_type          | text
+input_cp                    | 0
+independent_var_types       | text, text, text
+n_folds                     | 0
+null_proxy                  | __NULL__
 </pre>
 
 -# Predict for data not previously seen by assuming NULL
@@ -1697,8 +1743,8 @@ CREATE AGGREGATE MADLIB_SCHEMA._compute_leaf_stats(
 DROP TYPE IF EXISTS MADLIB_SCHEMA._tree_result_type CASCADE;
 CREATE TYPE MADLIB_SCHEMA._tree_result_type AS (
     tree_state      MADLIB_SCHEMA.BYTEA8,
-    finished        smallint, -- 0 running, 1 finished, 2 failed
-    tree_depth      smallint  -- depth of the returned tree (0 = root node)
+    finished        smallint,  -- 0 running, 1 finished, 2 failed
+    tree_depth      smallint   -- depth of the returned tree (0 = root node)
 );
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._dt_apply(
@@ -1785,6 +1831,16 @@ LANGUAGE C IMMUTABLE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 -------------------------------------------------------------------------
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._get_var_importance(
+    tree            MADLIB_SCHEMA.BYTEA8,
+    n_cat_features  INTEGER,
+    n_con_features  INTEGER
+) RETURNS DOUBLE PRECISION[] AS
+    'MODULE_PATHNAME', 'get_variable_importance'
+LANGUAGE C IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
+-------------------------------------------------------------------------
+
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._predict_dt_response(
     tree            MADLIB_SCHEMA.BYTEA8,
     cat_features    INTEGER[],

http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in b/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
index ba0c75e..dee3e32 100644
--- a/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
@@ -287,6 +287,7 @@ SELECT tree_train('dt_golf'::text,         -- source table
 
 SELECT _print_decision_tree(tree) from train_output;
 SELECT tree_display('train_output', False);
+SELECT impurity_var_importance FROM train_output;
 SELECT * FROM train_output_cv;
 -------------------------------------------------------------------------
 
@@ -309,6 +310,8 @@ SELECT tree_train('dt_golf'::text,         -- source table
                          );
 
 SELECT _print_decision_tree(tree) from train_output;
+SELECT tree_display('train_output', FALSE);
+SELECT * FROM train_output;
 SELECT tree_display('train_output', False);
 
 
@@ -354,6 +357,8 @@ SELECT tree_train('dt_golf'::text,         -- source table
 
 SELECT _print_decision_tree(tree) from train_output;
 SELECT tree_display('train_output', False);
+SELECT tree_surr_display('train_output');
+SELECT * FROM train_output;
 SELECT tree_predict('train_output', 'dt_golf', 'predict_output');
 \x off
 SELECT *
@@ -367,6 +372,31 @@ select * from train_output;
 select * from train_output_summary;
 -------------------------------------------------------------------------
 
+-- variable importance check
+DROP TABLE IF EXISTS train_output, train_output_summary, predict_output;
+SELECT tree_train('dt_golf'::text,         -- source table
+                  'train_output'::text,    -- output model table
+                  'id'::text,              -- id column
+                  'temperature::double precision'::text,           -- response
+                  '"OUTLOOK", temperature'::text,   -- features
+                  NULL::text,        -- exclude columns
+                  'mse'::text,       -- split criterion
+                  NULL::text,      -- grouping col
+                  NULL::text,        -- no weights
+                  NULL::integer,     -- max depth
+                  6::integer,        -- min split
+                  2::integer,        -- min bucket
+                  3::integer,        -- number of bins per continuous variable
+                  'cp=0.01',
+                  ''
+                  );
+
+SELECT impurity_var_importance FROM train_output;
+SELECT assert(impurity_var_importance[2] > 90,
+              'Variable importance not valid for extreme case')
+FROM train_output;
+-------------------------------------------------------------------------
+
 drop table if exists group_cp;
 create table group_cp(class TEXT,
                       explore_value  DOUBLE PRECISION);
@@ -405,9 +435,6 @@ select __build_tree(
     );
 
 select tree_display('train_output', FALSE);
-
-\d train_output_summary
-\x on
 select * from train_output;
 select * from train_output_summary;