You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ri...@apache.org on 2018/06/26 08:58:09 UTC
madlib git commit: DT: Add impurity variable importance
Repository: madlib
Updated Branches:
refs/heads/master 69ab239b8 -> b8031a03f
DT: Add impurity variable importance
JIRA: MADLIB-1205
Breiman et al. [1] describe a "gini importance" measure that can be
computed for a single decision tree. This measure is the impurity
decrease produced by any given feature in a node, accumulated over the
whole tree. Surrogates can also be added to this by scaling the impurity
decrease with the adjusted surrogate agreement.
This commit adds this importance measure for all the impurity functions
(hence the term impurity importance instead of gini importance).
[1] https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#giniimp
Closes #277
Co-authored-by: Nandish Jayaram <nj...@apache.org>
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/b8031a03
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/b8031a03
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/b8031a03
Branch: refs/heads/master
Commit: b8031a03f0bdb4481b3470831a07fc167b5d46ec
Parents: 69ab239
Author: Rahul Iyer <ri...@apache.org>
Authored: Tue May 29 18:09:18 2018 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Tue Jun 26 01:56:21 2018 -0700
----------------------------------------------------------------------
src/modules/recursive_partitioning/DT_impl.hpp | 91 +++++++-
src/modules/recursive_partitioning/DT_proto.hpp | 23 ++
.../recursive_partitioning/decision_tree.cpp | 35 ++-
.../recursive_partitioning/decision_tree.hpp | 3 +-
.../recursive_partitioning/decision_tree.py_in | 209 +++++++++++-------
.../recursive_partitioning/decision_tree.sql_in | 220 ++++++++++++-------
.../test/decision_tree.sql_in | 33 ++-
7 files changed, 428 insertions(+), 186 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/modules/recursive_partitioning/DT_impl.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/DT_impl.hpp b/src/modules/recursive_partitioning/DT_impl.hpp
index 69d0817..d9477c9 100644
--- a/src/modules/recursive_partitioning/DT_impl.hpp
+++ b/src/modules/recursive_partitioning/DT_impl.hpp
@@ -116,13 +116,11 @@ template <class Container>
inline
void
DecisionTree<Container>::bind(ByteStream_type& inStream) {
-
inStream >> tree_depth
>> n_y_labels
>> max_n_surr
>> is_regression
>> impurity_type;
-
size_t n_nodes = 0;
size_t n_labels = 0;
size_t max_surrogates = 0;
@@ -406,8 +404,8 @@ DecisionTree<Container>::impurityGain(const ColumnVector &combined_stats,
double false_count = statWeightedCount(combined_stats.segment(stats_per_split, stats_per_split));
double total_count = true_count + false_count;
- if (true_count == 0 || false_count == 0) {
- // no gain if all fall into one side
+ if (total_count == 0 || true_count == 0 || false_count == 0) {
+ // no gain if no tuples incoming or if all fall into one side
return 0.;
}
double true_weight = true_count / total_count;
@@ -574,7 +572,6 @@ DecisionTree<Container>::expand(const Accumulator &state,
max_stats.segment(0, sps), // true_stats
max_stats.segment(sps, sps) // false_stats
);
-
} else {
feature_indices(current) = FINISHED_LEAF;
}
@@ -619,11 +616,11 @@ DecisionTree<Container>::pickSurrogates(
// we use the *_agg_matrix to add every two alternate columns of the
// *_stats matrix to create a forward and reverse agreement metric
// for each split.
- // eg. in cat_stats,
- // we add columns 1 and 3 to get the <= split agreement for 1st cat split.
- // we add columns 2 and 4 to get the > split agreement for 1st cat split.
- // we add columns 5 and 7 to get the <= split agreement for 2nd cat split.
- // we add columns 6 and 8 to get the > split agreement for 2nd cat split.
+ // eg. For cat_stats,
+ // add columns 1 and 3 to get the <= split agreement for 1st cat split.
+ // add columns 2 and 4 to get the > split agreement for 1st cat split.
+ // add columns 5 and 7 to get the <= split agreement for 2nd cat split.
+ // add columns 6 and 8 to get the > split agreement for 2nd cat split.
ColumnVector fwd_agg_vec(4);
fwd_agg_vec << 1, 0, 1, 0;
ColumnVector rev_agg_vec(4);
@@ -1460,6 +1457,80 @@ DecisionTree<Container>::encodeIndex(const int &feature_index,
+// Accumulate the impurity importance of every feature into the two output
+// vectors (the caller is expected to pass zero-initialized vectors sized to
+// the number of categorical / continuous features).
+//
+// For each internal node, the impurity gain of its primary split is added to
+// the importance of the split feature. Every existing surrogate split of the
+// node additionally contributes split_gain * adjusted_agreement, where
+//   adjusted_agreement = (surrogate_agreement - majority_count) / minority_count
+// i.e. how much better the surrogate does than sending every row down the
+// majority branch, relative to the rows that rule would misroute.
+//
+// @param cat_var_importance  per-categorical-feature importance (in/out)
+// @param con_var_importance  per-continuous-feature importance (in/out)
template <class Container>
inline
+void
+DecisionTree<Container>::computeVariableImportance(
+        ColumnVector &cat_var_importance,
+        ColumnVector &con_var_importance){
+
+    // stats_per_split
+    uint16_t sps = is_regression ? REGRESS_N_STATS :
+                                   static_cast<uint16_t>(n_y_labels + 1);
+
+    // Only the first half of the node array can hold nodes with children
+    // (children of node i are at 2i+1 and 2i+2).
+    for (Index node_index = 0;
+            node_index < feature_indices.size() / 2;
+            node_index++){
+        if (!isInternalNode(node_index))
+            continue;
+
+        ColumnVector combined_stats(sps * 2);
+        combined_stats << predictions.row(trueChild(node_index)).transpose(),
+                          predictions.row(falseChild(node_index)).transpose();
+        double split_gain = impurityGain(combined_stats, sps);
+
+        // primary split contribution to importance
+        Index feat_index = feature_indices(node_index);
+        if (is_categorical(node_index)) {
+            assert(feat_index < cat_var_importance.size());
+            cat_var_importance(feat_index) += split_gain;
+        } else {
+            assert(feat_index < con_var_importance.size());
+            con_var_importance(feat_index) += split_gain;
+        }
+
+        // surrogate contribution to importance (no iterations if max_n_surr == 0)
+        for (Index surr_count = 0; surr_count < max_n_surr; surr_count++){
+            Index surr_lookup_index = node_index * max_n_surr + surr_count;
+
+            // surr_status == 0 implies non-existing surrogate
+            if (surr_status(surr_lookup_index) == 0)
+                break;
+
+            uint64_t node_count = nodeCount(node_index);
+            uint64_t maj_count = getMajorityCount(node_index);
+            if (node_count <= maj_count) {
+                // No minority rows: the adjusted agreement is undefined
+                // (division by zero), so no surrogate of this node adds
+                // anything.
+                break;
+            }
+            double min_count = static_cast<double>(node_count - maj_count);
+
+            // Evaluate in floating point. maj_count is uint64_t, so an
+            // integer expression (agreement - maj_count) / min_count would
+            // use unsigned arithmetic: the subtraction could wrap around and
+            // the division would truncate the (normally < 1) ratio to 0.
+            double adj_agreement =
+                (static_cast<double>(surr_agreement(surr_lookup_index)) -
+                 static_cast<double>(maj_count)) / min_count;
+
+            Index surr_feat_index = surr_indices(surr_lookup_index);
+            if (std::abs(surr_status(surr_lookup_index)) == 1){
+                assert(surr_feat_index < cat_var_importance.size());
+                cat_var_importance(surr_feat_index) += split_gain * adj_agreement;
+            } else {
+                assert(surr_feat_index < con_var_importance.size());
+                con_var_importance(surr_feat_index) += split_gain * adj_agreement;
+            }
+        }
+    }
+}
+// -------------------------------------------------------------------------
+
+template <class Container>
+inline
string
DecisionTree<Container>::surr_display(
ArrayHandle<text*> &cat_features_str,
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/modules/recursive_partitioning/DT_proto.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/DT_proto.hpp b/src/modules/recursive_partitioning/DT_proto.hpp
index 1852f69..6980326 100644
--- a/src/modules/recursive_partitioning/DT_proto.hpp
+++ b/src/modules/recursive_partitioning/DT_proto.hpp
@@ -29,6 +29,17 @@ using namespace dbal::eigen_integration;
using std::vector;
using std::string;
+
+// stats_per_split is the number of statistics accumulated for each split.
+// It differs between classification and regression: for classification it
+// depends on the number of response values and is computed at runtime,
+// whereas for regression it is the constant REGRESS_N_STATS below.
+
+// For classification, accumulate the weighted count of tuples for each
+// possible response value plus the unweighted count of tuples (n_y_labels + 1).
+
+// For regression, accumulate 4 values for evaluating a split:
+// weight, weight * response, weight * response^2, num of unweighted rows
const uint16_t REGRESS_N_STATS = 4u;
template <class Container>
@@ -93,6 +104,13 @@ public:
}
Index trueChild(Index current) const { return 2 * current + 1; }
Index falseChild(Index current) const { return 2 * current + 2; }
+    // True unless the node slot holds one of the marker values for a
+    // non-existing node, an in-process leaf, or a finished leaf.
+    bool isInternalNode(const Index node_index) const {
+        const int marker = feature_indices(node_index);
+        return !(marker == NODE_NON_EXISTING ||
+                 marker == IN_PROCESS_LEAF ||
+                 marker == FINISHED_LEAF);
+    }
+
double impurity(const ColumnVector & stats) const;
double impurityGain(const ColumnVector &combined_stats,
const uint16_t &stats_per_split) const;
@@ -149,6 +167,9 @@ public:
const int &is_categorical,
const int &n_cat_features) const;
+ void computeVariableImportance(ColumnVector& cat_var_importance,
+ ColumnVector& con_var_importance);
+
// attributes
// dimension information
uint16_type tree_depth; // 1 for root-only tree
@@ -172,6 +193,7 @@ public:
IntegerVector_type feature_indices;
// elements are of integer type for categorical
ColumnVector_type feature_thresholds;
+
// used as boolean array, 0 means continuous, otherwise categorical
IntegerVector_type is_categorical;
@@ -207,6 +229,7 @@ public:
Matrix_type predictions; // used as integer if we do classification
};
+
// ------------------------------------------------------------------------
// TreeAccumulator is used for collecting statistics during training the nodes
// for a level of the decision tree. The same accumulator is also used for
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/modules/recursive_partitioning/decision_tree.cpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/decision_tree.cpp b/src/modules/recursive_partitioning/decision_tree.cpp
index 5579cb5..42ea942 100644
--- a/src/modules/recursive_partitioning/decision_tree.cpp
+++ b/src/modules/recursive_partitioning/decision_tree.cpp
@@ -54,7 +54,6 @@ initialize_decision_tree::run(AnyType & args){
std::string impurity_func_str = args[1].getAs<std::string>();
uint16_t n_y_labels = args[2].getAs<uint16_t>();
uint16_t max_n_surr = args[3].getAs<uint16_t>();
-
if (is_regression_tree)
n_y_labels = REGRESS_N_STATS;
dt.rebind(1u, n_y_labels, max_n_surr, is_regression_tree);
@@ -174,10 +173,7 @@ compute_leaf_stats_transition::run(AnyType & args){
}
}
- // For classification, we store for each split the number of weighted
- // tuples for each possible response value and the number of unweighted
- // tuples landing on that node.
- // For regression, REGRESS_N_STATS determines the number of stats per split
+ // see DT_proto.hpp for explanation on stats_per_split
uint16_t stats_per_split = dt.is_regression ?
REGRESS_N_STATS : static_cast<uint16_t>(n_response_labels + 1);
const bool weights_as_rows = args[9].getAs<bool>();
@@ -491,6 +487,30 @@ print_decision_tree::run(AnyType &args){
}
AnyType
+get_variable_importance::run(AnyType &args){
+    // args: [0] serialized tree, [1] #categorical features,
+    //       [2] #continuous features
+    Tree dt = args[0].getAs<ByteString>();
+    const int n_cat_features = args[1].getAs<int>();
+    const int n_con_features = args[2].getAs<int>();
+
+    ColumnVector cat_var_importance = ColumnVector::Zero(n_cat_features);
+    ColumnVector con_var_importance = ColumnVector::Zero(n_con_features);
+    dt.computeVariableImportance(cat_var_importance, con_var_importance);
+
+    // Variable importance is scaled to represent a percentage. Even though
+    // the importance values are split between categorical and continuous, the
+    // percentages are relative to the combined set (categorical first).
+    ColumnVector combined_var_imp(n_cat_features + n_con_features);
+    combined_var_imp << cat_var_importance, con_var_importance;
+
+    // Normalize by the exact total. Impurity gains carry the scale of the
+    // impurity function (e.g. response variance for regression), so adding a
+    // fixed epsilon to the denominator would distort the percentages whenever
+    // the total is small. A tree without splits (total == 0) keeps all-zero
+    // importance instead of dividing by zero.
+    const double total_var_imp = combined_var_imp.sum();
+    if (total_var_imp > 0)
+        combined_var_imp *= (100.0 / total_var_imp);
+
+    return combined_var_imp;
+}
+
+AnyType
display_text_tree::run(AnyType &args){
Tree dt = args[0].getAs<ByteString>();
ArrayHandle<text*> cat_feature_names = args[1].getAs<ArrayHandle<text*> >();
@@ -511,8 +531,8 @@ display_text_tree::run(AnyType &args){
void mark_subtree_removal_recur(MutableTree &dt, int me) {
if (me < dt.predictions.rows() &&
dt.feature_indices(me) != dt.NODE_NON_EXISTING) {
- int left = static_cast<int>(dt.trueChild(static_cast<Index>(me))),
- right = static_cast<int>(dt.falseChild(static_cast<Index>(me)));
+ int left = static_cast<int>(dt.trueChild(static_cast<Index>(me)));
+ int right = static_cast<int>(dt.falseChild(static_cast<Index>(me)));
mark_subtree_removal_recur(dt, left);
mark_subtree_removal_recur(dt, right);
dt.feature_indices(me) = dt.NODE_NON_EXISTING;
@@ -724,6 +744,7 @@ prune_and_cplist::run(AnyType &args){
std::vector<double> node_complexities(dt.feature_indices.size(), alpha);
prune_tree(dt, 0, alpha, root_risk, node_complexities);
+
// Get the new tree_depth after pruning
// Note: externally, tree_depth starts from 0 but DecisionTree assumes
// tree_depth starts from 1
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/modules/recursive_partitioning/decision_tree.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/decision_tree.hpp b/src/modules/recursive_partitioning/decision_tree.hpp
index 3a3965f..ae62bfa 100644
--- a/src/modules/recursive_partitioning/decision_tree.hpp
+++ b/src/modules/recursive_partitioning/decision_tree.hpp
@@ -14,6 +14,7 @@ DECLARE_UDF(recursive_partitioning, compute_surr_stats_transition)
DECLARE_UDF(recursive_partitioning, dt_surr_apply)
DECLARE_UDF(recursive_partitioning, print_decision_tree)
+DECLARE_UDF(recursive_partitioning, get_variable_importance)
DECLARE_UDF(recursive_partitioning, predict_dt_response)
DECLARE_UDF(recursive_partitioning, predict_dt_prob)
@@ -24,5 +25,5 @@ DECLARE_UDF(recursive_partitioning, display_text_tree)
DECLARE_UDF(recursive_partitioning, convert_to_rpart_format)
DECLARE_UDF(recursive_partitioning, get_split_thresholds)
DECLARE_UDF(recursive_partitioning, prune_and_cplist)
-
+
DECLARE_UDF(recursive_partitioning, convert_to_random_forest_format)
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
index 04fde7e..57c3025 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
@@ -28,6 +28,7 @@ from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import is_var_valid
from utilities.validate_args import unquote_ident
from utilities.utilities import _assert
+from utilities.utilities import _array_to_string
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import unique_string
from utilities.utilities import add_postfix
@@ -187,7 +188,6 @@ def _classify_features(feature_to_type, features):
if feature_to_type[c] in boolean_types]
# Integer types are not considered continuous
- con_types = ['real', 'double precision', 'numeric', 'decimal']
con_features = [c for c in features
if is_psql_numeric_type(feature_to_type[c], exclude=int_types)]
@@ -331,6 +331,11 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion,
if 'cp_list' in pruned_tree:
tree['cp_list'] = pruned_tree['cp_list']
+ importance_vectors = _compute_var_importance(
+ schema_madlib, tree,
+ len(cat_features), len(con_features))
+ tree.update(**importance_vectors)
+
return tree_states, bins, dep_list, n_rows
# -------------------------------------------------------------------------
@@ -394,7 +399,8 @@ def _build_tree(schema_madlib, is_classification, split_criterion,
# we need to let it know that right now it is called from CV.
use_existing_tables = table_exists(output_table_name) # create tables if it does not exist
running_cv = True # flag to indicate that cv fold ID needs to be included
- _create_output_tables(**locals())
+ _create_output_tables(list_of_features_to_exclude='',
+ **locals())
# ------------------------------------------------------------------------------
@@ -483,6 +489,13 @@ def tree_train(schema_madlib, training_table_name, output_table_name,
cat_features, ordered_cat_features, boolean_cats, con_features = \
_classify_features(all_cols_types, features)
+ # assert that the continuous and categorical features together
+ # cover all features
+ invalid_features = set(features) - (set(cat_features) | set(con_features))
+ _assert(not invalid_features,
+ "DT error: Some of the features are invalid ({0})".
+ format(invalid_features))
+
# get all rows
n_all_rows = plpy.execute("SELECT count(*) FROM {source_table}".
format(source_table=training_table_name)
@@ -513,6 +526,7 @@ def tree_train(schema_madlib, training_table_name, output_table_name,
def _create_output_tables(schema_madlib, training_table_name, output_table_name,
tree_states, bins, split_criterion,
id_col_name, dependent_variable, list_of_features,
+ list_of_features_to_exclude,
is_classification, n_all_rows, n_rows, dep_list, cp,
all_cols_types, grouping_cols=None,
use_existing_tables=False, running_cv=False,
@@ -532,7 +546,7 @@ def _create_output_tables(schema_madlib, training_table_name, output_table_name,
_create_summary_table(
schema_madlib, split_criterion, training_table_name,
output_table_name, id_col_name, bins['cat_features'], bins['con_features'],
- dependent_variable, list_of_features,
+ dependent_variable, list_of_features, list_of_features_to_exclude,
failed_groups, is_classification, n_all_rows,
n_rows, dep_list, all_cols_types, cp, grouping_cols, 1,
use_existing_tables, n_folds, null_proxy)
@@ -783,35 +797,45 @@ def _create_result_table(schema_madlib, tree_state,
header = "create table " + output_table_name + " as "
depth = (tree_state['pruned_depth'] if 'pruned_depth' in tree_state
else tree_state['tree_depth'])
+
if len(cat_features) > 0:
sql = header + """
SELECT
- {cp} as pruning_cp,
- $1 as tree,
- $2 as cat_levels_in_text,
- $3 as cat_n_levels,
- {depth} as tree_depth
+ {cp} AS pruning_cp,
+ $1 AS tree,
+ $2 AS cat_levels_in_text,
+ $3 AS cat_n_levels,
+ $4 AS impurity_var_importance,
+ {depth} AS tree_depth
{fold}
""".format(depth=depth,
cp=tree_state['cp'],
fold=fold)
sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib),
- 'text[]', 'integer[]'])
- plpy.execute(sql_plan, [tree_state['tree_state'], cat_origin, cat_n])
+ 'text[]',
+ 'integer[]',
+ 'double precision[]'])
+ plpy.execute(sql_plan, [tree_state['tree_state'],
+ cat_origin,
+ cat_n,
+ tree_state['impurity_var_importance']])
else:
sql = header + """
SELECT
- {cp} as pruning_cp,
- $1 as tree,
- NULL::text[] as cat_levels_in_text,
- NULL::integer[] as cat_n_levels,
- {depth} as tree_depth
+ {cp} AS pruning_cp,
+ $1 AS tree,
+ NULL::text[] AS cat_levels_in_text,
+ NULL::integer[] AS cat_n_levels,
+ $2 AS impurity_var_importance,
+ {depth} AS tree_depth
{fold}
""".format(depth=depth,
cp=tree_state['cp'],
fold=fold)
- sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib)])
- plpy.execute(sql_plan, [tree_state['tree_state']])
+ sql_plan = plpy.prepare(sql, ['{0}.bytea8'.format(schema_madlib),
+ 'double precision[]'])
+ plpy.execute(sql_plan, [tree_state['tree_state'],
+ tree_state['impurity_var_importance']])
# ------------------------------------------------------------
@@ -1007,8 +1031,8 @@ def _get_bins_grps(
# cat_origin = [0, 1, 4, 6, 8, 0, 1, 4, 6, 0, 1, 4, 6, 8]
# grp_key_cat = ['3', '4', '5']
cat_items_list = [rows[col]
- for grp_key, rows in grp_to_col_to_levels
- for col in cat_features if col in rows]
+ for grp_key, rows in grp_to_col_to_levels
+ for col in cat_features if col in rows]
cat_n = [len(i) for i in cat_items_list]
cat_origin = [item for sublist in cat_items_list for item in sublist]
grp_key_cat = [item[0] for item in grp_to_col_to_levels]
@@ -1097,28 +1121,21 @@ def _one_step(schema_madlib, training_table_name, cat_features,
"$3", "$2",
null_proxy)
- # The arguments of the aggregate (in the same order):
- # 1. current tree state, madlib.bytea8
- # 2. categorical features (integer format) in a single array
- # 3. continuous features in a single array
- # 4. weight value
- # 5. categorical sorted levels (integer format) in a combined array
- # 6. continuous splits
- # 7. number of dependent levels
train_sql = """
SELECT (result).* from (
SELECT
- {schema_madlib}._dt_apply($1,
+ {schema_madlib}._dt_apply(
+ $1,
{schema_madlib}._compute_leaf_stats(
- $1,
- {cat_features_str},
- {con_features_str},
+ $1, -- current tree state, madlib.bytea8
+ {cat_features_str}, -- categorical features in an array
+ {con_features_str}, -- continuous features in an array
{dep_var},
- {weights},
- $2,
- $4,
- {dep_n_levels}::smallint,
- {subsample}::boolean
+ {weights}, -- weight value
+ $2, -- categorical sorted levels in a combined array
+ $4, -- continuous splits
+ {dep_n_levels}::smallint, -- number of dependent levels
+ {subsample}::boolean -- should we use a subsample of data
),
$4,
{min_split}::smallint,
@@ -1364,12 +1381,14 @@ def _create_grp_result_table(
running_cv=False, k=0):
""" Create the output table for grouping case.
"""
+
grp_key = unique_string()
tree_state = unique_string()
tree_depth = unique_string()
cp_col = unique_string()
cat_levels_in_text = unique_string()
cat_n_levels = unique_string()
+ impurity_var_importance = unique_string()
cat_levels_val = cat_levels_in_text if cat_features else "NULL::TEXT[]"
cat_n_levels_val = cat_n_levels if cat_features else "NULL::INTEGER[]"
grouping_array_str = bins['grouping_array_str']
@@ -1383,25 +1402,28 @@ def _create_grp_result_table(
sql = header + """
SELECT
{grouping_cols},
- {tree_state} as tree,
- {cat_levels_val} as cat_levels_in_text,
- {cat_n_levels_val} as cat_n_levels,
- {tree_depth} as tree_depth,
- {cp_col} as pruning_cp
+ {tree_state} AS tree,
+ {cat_levels_val} AS cat_levels_in_text,
+ {cat_n_levels_val} AS cat_n_levels,
+ string_to_array(trim(both '{{}}' FROM {impurity_var_importance}),
+ ',')::double precision[] AS impurity_var_importance,
+ {tree_depth} AS tree_depth,
+ {cp_col} AS pruning_cp
{fold}
FROM (
SELECT
{grouping_cols},
- {grouping_array_str} as {grp_key}
+ {grouping_array_str} AS {grp_key}
FROM {training_table_name}
group by {grouping_cols}
) s1
JOIN (
SELECT
- unnest($1) as {grp_key},
- unnest($2) as {tree_state},
- unnest($3) as {tree_depth},
- unnest($4) as {cp_col}
+ unnest($1) AS {grp_key},
+ unnest($2) AS {tree_state},
+ unnest($3) AS {tree_depth},
+ unnest($4) AS {cp_col},
+ unnest($5) AS {impurity_var_importance}
) s2
USING ({grp_key})
"""
@@ -1413,41 +1435,30 @@ def _create_grp_result_table(
cat_n_levels as {cat_n_levels},
cat_levels_in_text as {cat_levels_in_text}
FROM
- {schema_madlib}._gen_cat_levels_set($5, $6, $7, $8)
-
+ {schema_madlib}._gen_cat_levels_set($6, $7, $8, $9)
) s3
USING ({grp_key})
"""
sql = sql.format(**locals())
+ prepare_list = ['text[]',
+ '{schema_madlib}.bytea8[]'.format(schema_madlib=schema_madlib),
+ 'integer[]', 'double precision[]', 'text[]']
+ execute_list = [
+ [t['grp_key'] for t in tree_states],
+ [t['tree_state'] for t in tree_states],
+ [t['pruned_depth'] if 'pruned_depth' in t else t['tree_depth']
+ for t in tree_states],
+ [t['cp'] for t in tree_states],
+ [_array_to_string(t['impurity_var_importance']) for t in tree_states]]
if cat_features:
- sql_plan = plpy.prepare(
- sql, ['text[]',
- '{schema_madlib}.bytea8[]'.format(schema_madlib=schema_madlib),
- 'integer[]', 'double precision[]',
- 'text[]', 'integer[]', 'integer', 'text[]'])
- plpy.execute(sql_plan, [
- [t['grp_key'] for t in tree_states],
- [t['tree_state'] for t in tree_states],
- [t['pruned_depth'] if 'pruned_depth' in t else t['tree_depth']
- for t in tree_states],
- [t['cp'] for t in tree_states],
+ prepare_list += ['text[]', 'integer[]', 'integer', 'text[]']
+ execute_list += [
bins['grp_key_cat'],
bins['cat_n'],
len(cat_features),
- bins['cat_origin']])
- else:
- sql_plan = plpy.prepare(sql, [
- 'text[]',
- '{schema_madlib}.bytea8[]'.format(schema_madlib=schema_madlib),
- 'integer[]',
- 'double precision[]'])
- plpy.execute(sql_plan, [
- [t['grp_key'] for t in tree_states],
- [t['tree_state'] for t in tree_states],
- [t['pruned_depth'] if 'pruned_depth' in t else t['tree_depth']
- for t in tree_states],
- [t['cp'] for t in tree_states]
- ])
+ bins['cat_origin']]
+ sql_plan = plpy.prepare(sql, prepare_list)
+ plpy.execute(sql_plan, execute_list)
# ------------------------------------------------------------
@@ -1476,6 +1487,7 @@ def _create_summary_table(
schema_madlib, split_criterion,
training_table_name, output_table_name, id_col_name,
cat_features, con_features, dependent_variable, list_of_features,
+ list_of_features_to_exclude,
num_failed_groups, is_classification, n_all_rows, n_rows,
dep_list, all_cols_types, cp, grouping_cols=None, n_groups=1,
use_existing_tables=False, n_folds=0, null_proxy=None):
@@ -1491,6 +1503,7 @@ def _create_summary_table(
dep_list_str = "NULL"
indep_type = ', '.join(all_cols_types[c] for c in cat_features + con_features)
dep_type = _get_dep_type(training_table_name, dependent_variable)
+ independent_varnames = ','.join(cat_features + con_features)
cat_features_str = ','.join(cat_features)
con_features_str = ','.join(con_features)
if grouping_cols:
@@ -1517,8 +1530,10 @@ def _create_summary_table(
'{training_table_name}'::text AS source_table,
'{output_table_name}'::text AS model_table,
'{id_col_name}'::text AS id_col_name,
+ '{list_of_features}'::text AS list_of_features,
+ '{list_of_features_to_exclude}'::text AS list_of_features_to_exclude,
'{dependent_variable}'::text AS dependent_varname,
- '{list_of_features}'::text AS independent_varnames,
+ '{independent_varnames}'::text AS independent_varnames,
'{cat_features_str}'::text AS cat_features,
'{con_features_str}'::text AS con_features,
{grouping_cols_str}::text AS grouping_cols,
@@ -1899,6 +1914,31 @@ def _prune_and_cplist(schema_madlib, tree, cp, compute_cp_list=False):
# -------------------------------------------------------------------------
+def _compute_var_importance(schema_madlib, tree,
+ n_cat_features, n_con_features):
+ """ Compute variable importance for categorical and continuous features
+
+ Args:
+ @param schema_madlib: str, MADlib schema name
+ @param tree: Tree data to prune
+ @param n_cat_features: int, Number of categorical features
+ @param n_con_features: int, Number of continuous features
+
+ Returns:
+ Dictionary containing following keys:
+ impurity_var_importance: Array of importance values
+ """
+ var_imp_sql = """
+ SELECT {schema_madlib}._get_var_importance(
+ $1, -- trained decision tree
+ {n_cat_features},
+ {n_con_features}) AS impurity_var_importance
+ """.format(**locals())
+ var_imp_plan = plpy.prepare(var_imp_sql, [schema_madlib + '.bytea8'])
+ return plpy.execute(var_imp_plan, [tree['tree_state']])[0]
+# ------------------------------------------------------------------------------
+
+
def _xvalidate(schema_madlib, tree_states, training_table_name, output_table_name,
id_col_name, dependent_variable,
list_of_features, list_of_features_to_exclude,
@@ -2056,6 +2096,10 @@ def _xvalidate(schema_madlib, tree_states, training_table_name, output_table_nam
tree['pruned_depth'] = pruned_tree['tree_depth']
else:
tree['pruned_depth'] = 0
+ importance_vectors = _compute_var_importance(
+ schema_madlib, tree,
+ len(cat_features), len(con_features))
+ tree.update(**importance_vectors)
plpy.execute("DROP TABLE {group_to_param_list_table}".format(**locals()))
# ------------------------------------------------------------
@@ -2157,13 +2201,14 @@ def _tree_train_using_bins(
tree_state = plpy.execute(
"""
SELECT {schema_madlib}._initialize_decision_tree(
- {is_regression_tree},
- '{split_criterion}'::text,
- {dep_n_levels}::smallint,
- {max_n_surr}::smallint)
- as tree_state, False as finished
+ {is_regression_tree},
+ '{split_criterion}'::text,
+ {dep_n_levels}::smallint,
+ {max_n_surr}::smallint
+ ) AS tree_state,
+ FALSE as finished
""".format(schema_madlib=schema_madlib,
- is_regression_tree=not is_classification,
+ is_regression_tree=(not is_classification),
split_criterion=split_criterion,
dep_n_levels=dep_n_levels,
max_n_surr=max_n_surr))[0]
@@ -2203,9 +2248,7 @@ def _tree_train_grps_using_bins(
{is_regression_tree},
'{split_criterion}'::text,
{dep_n_levels}::smallint,
- {max_n_surr}::smallint
- )
- AS tree_state,
+ {max_n_surr}::smallint) AS tree_state,
0 AS finished
""".format(schema_madlib=schema_madlib,
is_regression_tree=not is_classification,
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
index 8e69d9b..9d5a5ef 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
@@ -279,6 +279,18 @@ tree_train(
for <em>weather_outlook</em> and 2 levels
<em>windy</em>.</td>
</tr>
+ <tr>
+ <th>impurity_var_importance</th>
+ <td>DOUBLE PRECISION[]. Impurity importance (also referred to as Gini
+ importance) of each variable. The order of the variables is the same as
+ that of 'independent_varnames' column in the summary table (see below).
+
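+ The impurity importance of a feature is the decrease in impurity
+ produced by every node that uses the feature as its primary split, summed
+ over the whole tree. If surrogates are used, the importance also includes,
+ for every node where the feature is a surrogate split, that node's impurity
+ decrease scaled by the adjusted surrogate agreement. The values are scaled
+ to percentages of the total importance across all variables.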
+ </td>
+ </tr>
<tr>
<th>tree_depth</th>
@@ -325,13 +337,24 @@ tree_train(
</tr>
<tr>
+ <th>list_of_features</th>
+ <td>TEXT. The list_of_features passed to the 'tree_train' procedure.</td>
+ </tr>
+
+ <tr>
+ <th>list_of_features_to_exclude</th>
+ <td>TEXT. The list_of_features_to_exclude passed to the 'tree_train' procedure.</td>
+ </tr>
+
+ <tr>
<th>dependent_varname</th>
<td>TEXT. The dependent variable.</td>
</tr>
<tr>
<th>independent_varnames</th>
- <td>TEXT. The independent variables.</td>
+ <td>TEXT. The independent variables. These are the features used in the
+ training of the decision tree.</td>
</tr>
<tr>
@@ -630,12 +653,17 @@ SELECT madlib.tree_train('dt_golf', -- source table
</pre>
View the output table (excluding the tree which is in binary format):
<pre class="example">
-SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
+\\x on
+SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
</pre>
<pre class="result">
- pruning_cp | cat_levels_in_text | cat_n_levels | tree_depth
-------------+----------------------------------+--------------+------------
- 0 | {overcast,rain,sunny,False,True} | {3,2} | 5
+-[ RECORD 1 ]-----------+--------------------------------------
+pruning_cp | 0
+cat_levels_in_text | {overcast,rain,sunny,False,True}
+cat_n_levels | {3,2}
+impurity_var_importance | {10.6171201061712,0,89.3828798938288}
+tree_depth | 5
+
</pre>
View the summary table:
<pre class="example">
@@ -643,27 +671,29 @@ View the summary table:
SELECT * FROM train_output_summary;
</pre>
<pre class="result">
--[ RECORD 1 ]---------+--------------------------------
-method | tree_train
-is_classification | t
-source_table | dt_golf
-model_table | train_output
-id_col_name | id
-dependent_varname | class
-independent_varnames | "OUTLOOK", temperature, windy
-cat_features | "OUTLOOK",windy
-con_features | temperature
-grouping_cols |
-num_all_groups | 1
-num_failed_groups | 0
-total_rows_processed | 14
-total_rows_skipped | 0
-dependent_var_levels | "Don't Play","Play"
-dependent_var_type | text
-input_cp | 0
-independent_var_types | text, boolean, double precision
-n_folds | 0
-null_proxy |
+-[ RECORD 1 ]---------------+--------------------------------
+method | tree_train
+is_classification | t
+source_table | dt_golf
+model_table | train_output
+id_col_name | id
+list_of_features | "OUTLOOK", temperature, windy
+list_of_features_to_exclude | None
+dependent_varname | class
+independent_varnames | "OUTLOOK",windy,temperature
+cat_features | "OUTLOOK",windy
+con_features | temperature
+grouping_cols |
+num_all_groups | 1
+num_failed_groups | 0
+total_rows_processed | 14
+total_rows_skipped | 0
+dependent_var_levels | "Don't Play","Play"
+dependent_var_type | text
+input_cp | 0
+independent_var_types | text, boolean, double precision
+n_folds | 0
+null_proxy |
</pre>
-# Predict output categories. For the purpose
@@ -890,16 +920,20 @@ SELECT madlib.tree_train('dt_golf', -- source table
1, -- min bucket
10 -- number of bins per continuous variable
);
-SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
+SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
</pre>
View the output table (excluding the tree which is in binary format):
<pre class="example">
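-SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
+SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;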
</pre>
<pre class="result">
- pruning_cp | cat_levels_in_text | cat_n_levels | tree_depth
-------------+------------------------------------------------+--------------+------------
- 0 | {medium,none,high,low,unhealthy,good,moderate} | {4,3} | 3
+-[ RECORD 1 ]-----------+-----------------------------------------------------
+pruning_cp | 0
+cat_levels_in_text | {medium,none,high,low,unhealthy,good,moderate}
+cat_n_levels | {4,3}
+impurity_var_importance | {0,40.2340084993653,5.6791213643137,54.086870136321}
+tree_depth | 3
+
</pre>
The first 4 levels correspond to cloud ceiling and the next 3 levels
correspond to air quality.
@@ -1032,12 +1066,15 @@ SELECT madlib.tree_train('mt_cars', -- source table
View the output table (excluding the tree which is in binary format)
which shows ordering of levels of categorical variables 'vs' and 'cyl':
<pre class="example">
-SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
+SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
</pre>
<pre class="result">
- pruning_cp | cat_levels_in_text | cat_n_levels | tree_depth
-------------+--------------------+--------------+------------
- 0 | {0,1,4,6,8} | {2,3} | 4
+-[ RECORD 1 ]-----------+-----------------------------------------------------------------------
+pruning_cp              | 0
+cat_levels_in_text      | {0,1,4,6,8}
+cat_n_levels            | {2,3}
+impurity_var_importance | {0,51.8593201959496,10.976977929129,5.31897402755374,31.8447278473677}
+tree_depth              | 4
</pre>
View the summary table:
<pre class="example">
@@ -1045,27 +1082,30 @@ View the summary table:
SELECT * FROM train_output_summary;
</pre>
<pre class="result">
--[ RECORD 1 ]---------+-----------------------------------------------------------------------
-method | tree_train
-is_classification | f
-source_table | mt_cars
-model_table | train_output
-id_col_name | id
-dependent_varname | mpg
-independent_varnames | *
-cat_features | vs,cyl
-con_features | disp,qsec,wt
-grouping_cols |
-num_all_groups | 1
-num_failed_groups | 0
-total_rows_processed | 32
-total_rows_skipped | 0
-dependent_var_levels |
-dependent_var_type | double precision
-input_cp | 0
-independent_var_types | integer, integer, double precision, double precision, double precision
-n_folds | 0
-null_proxy |
+-[ RECORD 1 ]---------------+-----------------------------------------------------------------------
+method | tree_train
+is_classification | f
+source_table | mt_cars
+model_table | train_output
+id_col_name | id
+list_of_features | *
+list_of_features_to_exclude | id, hp, drat, am, gear, carb
+dependent_varname | mpg
+independent_varnames | vs,cyl,disp,qsec,wt
+cat_features | vs,cyl
+con_features | disp,qsec,wt
+grouping_cols |
+num_all_groups | 1
+num_failed_groups | 0
+total_rows_processed | 32
+total_rows_skipped | 0
+dependent_var_levels |
+dependent_var_type | double precision
+input_cp | 0
+independent_var_types | integer, integer, double precision, double precision, double precision
+n_folds | 0
+null_proxy |
+
</pre>
-# Predict regression output for the same data and compare with original:
@@ -1212,12 +1252,16 @@ View the output table (excluding the tree which is in binary format).
The input cp value was 0 (default) and the best 'pruning_cp' value
turns out to be 0 as well in this small example:
<pre class="example">
-SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
+SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
</pre>
<pre class="result">
- pruning_cp | cat_levels_in_text | cat_n_levels | tree_depth
-------------+--------------------+--------------+------------
- 0 | {0,1,4,6,8} | {2,3} | 4
+-[ RECORD 1 ]-----------+-----------------------------------------------------------------------
+pruning_cp | 0
+cat_levels_in_text | {0,1,4,6,8}
+cat_n_levels | {2,3}
+impurity_var_importance | {0,51.8593201959496,10.976977929129,5.31897402755374,31.8447278473677}
+tree_depth | 4
+
</pre>
The cp values tested and average error and standard deviation are:
<pre class="example">
@@ -1288,27 +1332,29 @@ View the summary table:
SELECT * FROM train_output_summary;
</pre>
<pre class='result'>
--[ RECORD 1 ]---------+-----------------------
-method | tree_train
-is_classification | t
-source_table | null_handling_example
-model_table | train_output
-id_col_name | id
-dependent_varname | response
-independent_varnames | country, weather, city
-cat_features | country,weather,city
-con_features |
-grouping_cols |
-num_all_groups | 1
-num_failed_groups | 0
-total_rows_processed | 4
-total_rows_skipped | 0
-dependent_var_levels | "a","b","c","d"
-dependent_var_type | text
-input_cp | 0
-independent_var_types | text, text, text
-n_folds | 0
-null_proxy | __NULL__
+-[ RECORD 1 ]---------------+-----------------------
+method | tree_train
+is_classification | t
+source_table | null_handling_example
+model_table | train_output
+id_col_name | id
+list_of_features | country, weather, city
+list_of_features_to_exclude | None
+dependent_varname | response
+independent_varnames | country,weather,city
+cat_features | country,weather,city
+con_features |
+grouping_cols | [NULL]
+num_all_groups | 1
+num_failed_groups | 0
+total_rows_processed | 4
+total_rows_skipped | 0
+dependent_var_levels | "a","b","c","d"
+dependent_var_type | text
+input_cp | 0
+independent_var_types | text, text, text
+n_folds | 0
+null_proxy | __NULL__
</pre>
-# Predict for data not previously seen by assuming NULL
@@ -1697,8 +1743,8 @@ CREATE AGGREGATE MADLIB_SCHEMA._compute_leaf_stats(
DROP TYPE IF EXISTS MADLIB_SCHEMA._tree_result_type CASCADE;
CREATE TYPE MADLIB_SCHEMA._tree_result_type AS (
    tree_state      MADLIB_SCHEMA.BYTEA8,
-    finished        smallint,        -- 0 running, 1 finished, 2 failed
-    tree_depth      smallint         -- depth of the returned tree (0 = root node)
+    finished        smallint,   -- training status: 0 running, 1 finished, 2 failed
+    tree_depth      smallint    -- depth of the returned tree (0 = tree is only the root node)
);
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._dt_apply(
@@ -1785,6 +1831,16 @@ LANGUAGE C IMMUTABLE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
-------------------------------------------------------------------------
+-- Compute the impurity importance of every feature of a trained tree.
+-- Arguments: the tree in binary (BYTEA8) form and the number of categorical
+-- and continuous features. Returns one DOUBLE PRECISION value per feature.
+-- NOTE(review): output order is presumably categorical features first, then
+-- continuous (matching 'independent_varnames'); confirm against the C
+-- implementation 'get_variable_importance'.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._get_var_importance(
+    tree                MADLIB_SCHEMA.BYTEA8,
+    n_cat_features      INTEGER,
+    n_con_features      INTEGER
+) RETURNS DOUBLE PRECISION[] AS
+    'MODULE_PATHNAME', 'get_variable_importance'
+LANGUAGE C IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
+-------------------------------------------------------------------------
+
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._predict_dt_response(
tree MADLIB_SCHEMA.BYTEA8,
cat_features INTEGER[],
http://git-wip-us.apache.org/repos/asf/madlib/blob/b8031a03/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in b/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
index ba0c75e..dee3e32 100644
--- a/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
@@ -287,6 +287,7 @@ SELECT tree_train('dt_golf'::text, -- source table
SELECT _print_decision_tree(tree) from train_output;
SELECT tree_display('train_output', False);
+-- Per-feature impurity importance of the trained tree.
+SELECT impurity_var_importance FROM train_output;
SELECT * FROM train_output_cv;
-------------------------------------------------------------------------
@@ -309,6 +310,8 @@ SELECT tree_train('dt_golf'::text, -- source table
);
SELECT _print_decision_tree(tree) from train_output;
+-- Inspect the full model row, including impurity_var_importance.
+-- (tree_display is already called on the next line; no duplicate call here.)
+SELECT * FROM train_output;
SELECT tree_display('train_output', False);
@@ -354,6 +357,8 @@ SELECT tree_train('dt_golf'::text, -- source table
SELECT _print_decision_tree(tree) from train_output;
SELECT tree_display('train_output', False);
+-- Show surrogate splits, then the full model row (incl. impurity_var_importance).
+SELECT tree_surr_display('train_output');
+SELECT * FROM train_output;
SELECT tree_predict('train_output', 'dt_golf', 'predict_output');
\x off
SELECT *
@@ -367,6 +372,31 @@ select * from train_output;
select * from train_output_summary;
-------------------------------------------------------------------------
+-- variable importance check
+-- Extreme case: 'temperature' is both the response and a feature, so splits
+-- on temperature should account for almost all of the impurity decrease.
+DROP TABLE IF EXISTS train_output, train_output_summary, predict_output;
+SELECT tree_train('dt_golf'::text,         -- source table
+                  'train_output'::text,    -- output model table
+                  'id'::text,              -- id column
+                  'temperature::double precision'::text,           -- response (also used as a feature)
+                  '"OUTLOOK", temperature'::text,   -- features
+                  NULL::text,        -- exclude columns
+                  'mse'::text,      -- split criterion
+                  NULL::text,        -- grouping col
+                  NULL::text,     -- no weights
+                  NULL::integer,     -- max depth
+                  6::integer,        -- min split
+                  2::integer,        -- min bucket
+                  3::integer,        -- number of bins per continuous variable
+                  'cp=0.01',         -- pruning parameters
+                  ''
+                  );
+
+SELECT impurity_var_importance FROM train_output;
+-- Categorical features come before continuous ones, so element 2 (1-based)
+-- is 'temperature'. Importances are reported on a 0-100 scale in the docs.
+SELECT assert(impurity_var_importance[2] > 90,
+       'Variable importance not valid for extreme case')
+FROM train_output;
+-------------------------------------------------------------------------
+
drop table if exists group_cp;
create table group_cp(class TEXT,
explore_value DOUBLE PRECISION);
@@ -405,9 +435,6 @@ select __build_tree(
);
select tree_display('train_output', FALSE);
-
-\d train_output_summary
-\x on
select * from train_output;
select * from train_output_summary;