You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ri...@apache.org on 2017/11/22 00:24:50 UTC

[2/2] madlib git commit: DT: Consolidate tree_rmse and tree_misclassified

DT: Consolidate tree_rmse and tree_misclassified


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/5b2ada5d
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/5b2ada5d
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/5b2ada5d

Branch: refs/heads/master
Commit: 5b2ada5d70022bdf1f3c8659b20b57b86310f4de
Parents: 4b7d9cc
Author: Rahul Iyer <ri...@apache.org>
Authored: Mon Nov 13 10:49:29 2017 -0800
Committer: Rahul Iyer <ri...@apache.org>
Committed: Tue Nov 21 16:22:27 2017 -0800

----------------------------------------------------------------------
 .../recursive_partitioning/decision_tree.py_in  | 569 +++++++++----------
 .../recursive_partitioning/decision_tree.sql_in |  24 +-
 .../modules/utilities/validate_args.py_in       |   6 +-
 3 files changed, 286 insertions(+), 313 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/5b2ada5d/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
index 88be1a2..2b9d434 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
@@ -186,167 +186,6 @@ def _classify_features(feature_to_type, features):
 # ------------------------------------------------------------
 
 
-def tree_train_help_message(schema_madlib, message, **kwargs):
-    """ Help message for Decision Tree
-    """
-    if not message:
-        help_string = """
-------------------------------------------------------------
-                        SUMMARY
-------------------------------------------------------------
-Functionality: Decision Tree
-
-Decision trees use a tree-based predictive model to
-predict the value of a target variable based on several input variables.
-
-For more details on the function usage:
-    SELECT {schema_madlib}.tree_train('usage');
-For an example on using this function:
-    SELECT {schema_madlib}.tree_train('example');
-        """
-    elif message.lower().strip() in ['usage', 'help', '?']:
-        help_string = """
-------------------------------------------------------------
-                        USAGE
-------------------------------------------------------------
-SELECT {schema_madlib}.tree_train(
-    'training_table',       -- Data table name
-    'output_table',         -- Table name to store the tree model
-    'id_col_name',          -- Row ID, used in tree_predict
-    'dependent_variable',   -- The column to fit
-    'list_of_features',     -- Comma separated column names to be
-                                used as the predictors, can be '*'
-                                to include all columns except the
-                                dependent_variable
-    'features_to_exclude',  -- Comma separated column names to be
-                                excluded if list_of_features is '*'
-    'split_criterion',      -- How to split a node, options are
-                                'gini', 'misclassification' and
-                                'entropy' for classification, and
-                                'mse' for regression.
-    'grouping_cols',        -- Comma separated column names used to
-                                group the data. A decision tree model
-                                will be created for each group. Default
-                                is NULL
-    'weights',              -- A Column name containing weights for
-                                each observation. Default is NULL
-    max_depth,              -- Maximum depth of any node, default is 7
-    min_split,              -- Minimum number of observations that must
-                                exist in a node for a split to be
-                                attemped, default is 20
-    min_bucket,             -- Minimum number of observations in any
-                                terminal node, default is min_split/3
-    n_bins,                 -- Number of bins to find possible node
-                                split threshold values for continuous
-                                variables, default is 20 (Must be greater than 1)
-    pruning_params,         -- A comma-separated text containing
-                                key=value pairs of parameters for pruning.
-                                Parameters accepted:
-                                    'cp' - complexity parameter with default=0.01,
-                                    'n_folds' - number of cross-validation folds
-                                        with default value of 0 (= no cross-validation)
-    null_handling_params,   -- A comma-separated text containing
-                                key=value pairs of parameters for handling NULL values.
-                                Parameters accepted:
-                                    'max_surrogates' - Maximum number of surrogates to
-                                        compute for each split
-                                    'null_as_category' - Boolean to indicate if
-                                        NULL should be treated as a special category
-    verbose                 -- Boolean, whether to print more info, default is False
-);
-
-------------------------------------------------------------
-                        OUTPUT
-------------------------------------------------------------
-The output table ('output_table' above) has the following columns (quoted items
-are of type TEXT):
-    <grouping columns>      -- Grouping columns, only present when
-                                'grouping_cols' is not NULL or ''
-    tree                    -- The decision tree model as a binary string
-    cat_levels_in_text      -- Distinct levels (casted to text) of all
-                                categorical variables combined in a single array
-    cat_n_levels            -- Number of distinct levels of all categorical variables
-    tree_depth              -- Number of levels in the tree (root has level 0)
-    pruning_cp              -- The cost-complexity parameter used for pruning
-                                the trained tree(s). This would be different
-                                from the input cp value if cross-validation is used.
-
-The output summary table ('output_table_summary') has the following columns:
-    'method'                -- Method name: 'tree_train'
-    'source_table'          -- Data table name
-    'model_table'           -- Tree model table name
-    'id_col_name'           -- Name of the 'id' column
-    is_classification       -- Boolean value indicating if tree is classification or regression
-    'dependent_varname'     -- Response variable column name
-    'independent_varnames'  -- Comma-separated feature column names
-    'cat_features'          -- Comma-separated column names of categorical variables
-    'con_features'          -- Comma-separated column names of continuous variables
-    'grouping_cols'         -- Grouping column names
-    num_all_groups          -- Number of groups
-    num_failed_groups       -- Number of groups for which training failed
-    total_rows_processed    -- Number of rows used in the model training
-    total_rows_skipped      -- Number of rows skipped because NULL values
-    dependent_var_levels    -- For classification, the distinct levels of
-                                the dependent variable
-    dependent_var_type      -- The type of dependent variable
-    input_cp                -- The complexity parameter (cp) used for pruning the
-                                 trained tree(s) (before cross-validation is run)
-    independent_var_types   -- The types of independent variables, comma-separated
-    k                       -- Number of folds (NULL if not using cross validation)
-    null_proxy              -- String used as replacement for NULL values
-                                (NULL if null_as_category = False)
-
-        """
-    elif message.lower().strip() in ['example', 'examples']:
-        help_string = """
-------------------------------------------------------------
-                        EXAMPLE
-------------------------------------------------------------
-DROP TABLE IF EXISTS dummy_dt_con_src CASCADE;
-CREATE TABLE dummy_dt_con_src (
-    id  INTEGER,
-    cat INTEGER[],
-    con FLOAT8[],
-    y   FLOAT8
-);
-
-INSERT INTO dummy_dt_src VALUES
-(1, '{{0}}'::INTEGER[], ARRAY[0], 0.5),
-(2, '{{0}}'::INTEGER[], ARRAY[1], 0.5),
-(3, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
-(4, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
-(5, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
-(6, '{{0}}'::INTEGER[], ARRAY[5], 0.1),
-(7, '{{0}}'::INTEGER[], ARRAY[6], 0.1),
-(8, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
-(9, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
-(10, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
-(11, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
-
-DROP TABLE IF EXISTS tree_out, tree_out_summary;
-SELECT madlib.tree_train(
-    'dummy_dt_src',
-    'tree_out',
-    'id',
-    'y',
-    'cat, con',
-    '',
-    'mse',
-    NULL::Text,
-    NULL::Text,
-    3,
-    2,
-    1,
-    5);
-
-SELECT madlib.tree_display('tree_out');
-        """
-    else:
-        help_string = "No such option. Use {schema_madlib}.tree_train('usage')"
-    return help_string.format(schema_madlib=schema_madlib)
-# ------------------------------------------------------------
-
-
 def _extract_pruning_params(pruning_params_str):
     """
     Args:
@@ -1979,77 +1818,6 @@ def tree_display(schema_madlib, model_table, dot_format=True, verbose=False,
 # -------------------------------------------------------------------------
 
 
-def tree_predict_help_message(schema_madlib, message, **kwargs):
-    """ Help message for Decision Tree predict
-    """
-    if not message:
-        help_string = """
-------------------------------------------------------------
-                        SUMMARY
-------------------------------------------------------------
-Functionality: Decision Tree Prediction
-
-Prediction for a decision tree (trained using {schema_madlib}.tree_train) can
-be performed on a new data table.
-
-For more details on the function usage:
-    SELECT {schema_madlib}.tree_predict('usage');
-For an example on using this function:
-    SELECT {schema_madlib}.tree_predict('example');
-        """
-    elif message.lower().strip() in ['usage', 'help', '?']:
-        help_string = """
-------------------------------------------------------------
-                        USAGE
-------------------------------------------------------------
-SELECT {schema_madlib}.tree_predict(
-    'tree_model',           -- Model table name (output of tree_train)
-    'new_data_table',       -- Prediction source table
-    'output_table',         -- Table name to store the predictions
-    'type'                  -- Type of prediction output
-);
-
-Note: The 'new_data_table' should have the same 'id_col_name' column as used
-in the training function. This is used to corelate the prediction data row with
-the actual prediction in the output table.
-
-------------------------------------------------------------
-                        OUTPUT
-------------------------------------------------------------
-The output table ('output_table' above) has the '<id_col_name>' column giving
-the 'id' for each prediction and the prediction columns for the response
-variable (also called as dependent variable).
-
-If prediction type = 'response', then the table has a single column with the
-prediction value of the response. The type of this column depends on the type
-of the response variable used during training.
-
-If prediction type = 'prob', then the table has multiple columns, one for each
-possible value of the response variable. The columns are labeled as
-'estimated_prob_<dep value>', where <dep value> represents for each value
-of the response.
-        """
-    elif message.lower().strip() in ['example', 'examples']:
-        help_string = """
-------------------------------------------------------------
-                        EXAMPLE
-------------------------------------------------------------
--- Assuming the example of tree_train() has been run
-SELECT {schema_madlib}.tree_predict(
-    'tree_out',
-    'dummy_dt_src',
-    'tree_predict_out',
-    'response'
-);
-
-SELECT * FROM tree_predict_out;
-        """
-    else:
-        help_string = "No such option. Use {schema_madlib}.tree_predict('usage')"
-    return help_string.format(schema_madlib=schema_madlib)
-# ------------------------------------------------------------
-
-
 def _prune_and_cplist(schema_madlib, tree, cp, compute_cp_list=False):
     """ Prune tree with given cost-complexity parameters
         and return a list of cp values at which tree can be pruned
@@ -2425,87 +2193,274 @@ def _tree_train_grps_using_bins(
 # ------------------------------------------------------------
 
 
-def _tree_rmse(schema_madlib, source_table, dependent_varname, prediction_table,
-               pred_dep_name, id_col_name, grouping_cols, output_table,
-               use_existing_tables=False, k=0, **kwargs):
-    old_messages = plpy.execute("SELECT setting FROM pg_settings "
-                                "WHERE name = 'client_min_messages'")[0]['setting']
-    plpy.execute('SET client_min_messages TO warning')
+def _tree_error(schema_madlib, source_table, dependent_varname,
+                prediction_table, pred_dep_name, id_col_name, grouping_cols,
+                output_table, is_classification,
+                use_existing_tables=False, k=0, **kwargs):
+    with MinWarning("warning"):
+        if use_existing_tables and table_exists(output_table):
+            # plpy.execute("truncate " + output_table)
+            header = "INSERT INTO " + output_table + " "
+        else:
+            header = "CREATE TABLE " + output_table + " AS "
+        if is_classification:
+            error_func = """
+                1.0 * sum(CASE WHEN ({prediction_table}.{pred_dep_name} =
+                                     {source_table}.{dependent_varname})
+                               THEN 0
+                               ELSE 1
+                          END) / count(*)
+                """.format(**locals())
+        else:
+            error_func = """
+                sqrt(avg(({prediction_table}.{pred_dep_name} -
+                          {source_table}.{dependent_varname})^2
+                        )
+                    )
+                """.format(**locals())
+        grouping_str = '' if not grouping_cols else "GROUP BY " + grouping_cols
+        grouping_col_str = '' if not grouping_cols else grouping_cols + ','
+
+        sql = header + """
+            SELECT
+                {grouping_col_str}
+                {error_func} as cv_error,
+                {k} as k
+            FROM {prediction_table}, {source_table}
+            WHERE {prediction_table}.{id_col_name} = {source_table}.{id_col_name}
+            {grouping_str}
+            """.format(**locals())
+        plpy.execute(sql)
+# ------------------------------------------------------------
+
+
+def tree_train_help_message(schema_madlib, message, **kwargs):
+    """ Help message for Decision Tree
+    """
+    if not message:
+        help_string = """
+------------------------------------------------------------
+                        SUMMARY
+------------------------------------------------------------
+Functionality: Decision Tree
 
-    fold = ", " + str(k) + " as k"
-    if use_existing_tables and table_exists(output_table):
-        # plpy.execute("truncate " + output_table)
-        header = "INSERT INTO " + output_table + " "
-    else:
-        header = "CREATE TABLE " + output_table + " AS "
+Decision trees use a tree-based predictive model to
+predict the value of a target variable based on several input variables.
 
-    grouping_str = '' if not grouping_cols else "GROUP BY " + grouping_cols
-    grouping_col_str = "" if not grouping_cols else grouping_cols + ','
+For more details on the function usage:
+    SELECT {schema_madlib}.tree_train('usage');
+For an example on using this function:
+    SELECT {schema_madlib}.tree_train('example');
+        """
+    elif message.lower().strip() in ['usage', 'help', '?']:
+        help_string = """
+------------------------------------------------------------
+                        USAGE
+------------------------------------------------------------
+SELECT {schema_madlib}.tree_train(
+    'training_table',       -- Data table name
+    'output_table',         -- Table name to store the tree model
+    'id_col_name',          -- Row ID, used in tree_predict
+    'dependent_variable',   -- The column to fit
+    'list_of_features',     -- Comma separated column names to be
+                                used as the predictors, can be '*'
+                                to include all columns except the
+                                dependent_variable
+    'features_to_exclude',  -- Comma separated column names to be
+                                excluded if list_of_features is '*'
+    'split_criterion',      -- How to split a node, options are
+                                'gini', 'misclassification' and
+                                'entropy' for classification, and
+                                'mse' for regression.
+    'grouping_cols',        -- Comma separated column names used to
+                                group the data. A decision tree model
+                                will be created for each group. Default
+                                is NULL
+    'weights',              -- A Column name containing weights for
+                                each observation. Default is NULL
+    max_depth,              -- Maximum depth of any node, default is 7
+    min_split,              -- Minimum number of observations that must
+                                exist in a node for a split to be
+                                attempted, default is 20
+    min_bucket,             -- Minimum number of observations in any
+                                terminal node, default is min_split/3
+    n_bins,                 -- Number of bins to find possible node
+                                split threshold values for continuous
+                                variables, default is 20 (Must be greater than 1)
+    pruning_params,         -- A comma-separated text containing
+                                key=value pairs of parameters for pruning.
+                                Parameters accepted:
+                                    'cp' - complexity parameter with default=0.01,
+                                    'n_folds' - number of cross-validation folds
+                                        with default value of 0 (= no cross-validation)
+    null_handling_params,   -- A comma-separated text containing
+                                key=value pairs of parameters for handling NULL values.
+                                Parameters accepted:
+                                    'max_surrogates' - Maximum number of surrogates to
+                                        compute for each split
+                                    'null_as_category' - Boolean to indicate if
+                                        NULL should be treated as a special category
+    verbose                 -- Boolean, whether to print more info, default is False
+);
 
-    sql = header + """
-        SELECT
-            {grouping_col_str}
-            sqrt(avg(
-                    ({prediction_table}.{pred_dep_name} -
-                     {source_table}.{dependent_varname})^2
-                    )
-                ) as cv_error
-            {fold}
-        FROM {prediction_table}, {source_table}
-        WHERE {prediction_table}.{id_col_name} = {source_table}.{id_col_name}
-        {grouping_str}
-        """.format(output_table=output_table,
-                   prediction_table=prediction_table,
-                   pred_dep_name=pred_dep_name,
-                   source_table=source_table,
-                   dependent_varname=dependent_varname,
-                   grouping_col_str=grouping_col_str,
-                   grouping_str=grouping_str,
-                   id_col_name=id_col_name,
-                   fold=fold)
-    plpy.execute(sql)
-    plpy.execute('SET client_min_messages TO ' + old_messages)
-# ------------------------------------------------------------
+------------------------------------------------------------
+                        OUTPUT
+------------------------------------------------------------
+The output table ('output_table' above) has the following columns (quoted items
+are of type TEXT):
+    <grouping columns>      -- Grouping columns, only present when
+                                'grouping_cols' is not NULL or ''
+    tree                    -- The decision tree model as a binary string
+    cat_levels_in_text      -- Distinct levels (casted to text) of all
+                                categorical variables combined in a single array
+    cat_n_levels            -- Number of distinct levels of all categorical variables
+    tree_depth              -- Number of levels in the tree (root has level 0)
+    pruning_cp              -- The cost-complexity parameter used for pruning
+                                the trained tree(s). This would be different
+                                from the input cp value if cross-validation is used.
+
+The output summary table ('output_table_summary') has the following columns:
+    'method'                -- Method name: 'tree_train'
+    'source_table'          -- Data table name
+    'model_table'           -- Tree model table name
+    'id_col_name'           -- Name of the 'id' column
+    is_classification       -- Boolean value indicating if tree is classification or regression
+    'dependent_varname'     -- Response variable column name
+    'independent_varnames'  -- Comma-separated feature column names
+    'cat_features'          -- Comma-separated column names of categorical variables
+    'con_features'          -- Comma-separated column names of continuous variables
+    'grouping_cols'         -- Grouping column names
+    num_all_groups          -- Number of groups
+    num_failed_groups       -- Number of groups for which training failed
+    total_rows_processed    -- Number of rows used in the model training
+    total_rows_skipped      -- Number of rows skipped because of NULL values
+    dependent_var_levels    -- For classification, the distinct levels of
+                                the dependent variable
+    dependent_var_type      -- The type of dependent variable
+    input_cp                -- The complexity parameter (cp) used for pruning the
+                                 trained tree(s) (before cross-validation is run)
+    independent_var_types   -- The types of independent variables, comma-separated
+    k                       -- Number of folds (NULL if not using cross validation)
+    null_proxy              -- String used as replacement for NULL values
+                                (NULL if null_as_category = False)
+
+        """
+    elif message.lower().strip() in ['example', 'examples']:
+        help_string = """
+------------------------------------------------------------
+                        EXAMPLE
+------------------------------------------------------------
+DROP TABLE IF EXISTS dummy_dt_src CASCADE;
+CREATE TABLE dummy_dt_src (
+    id  INTEGER,
+    cat INTEGER[],
+    con FLOAT8[],
+    y   FLOAT8
+);
 
+INSERT INTO dummy_dt_src VALUES
+(1, '{{0}}'::INTEGER[], ARRAY[0], 0.5),
+(2, '{{0}}'::INTEGER[], ARRAY[1], 0.5),
+(3, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
+(4, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
+(5, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
+(6, '{{0}}'::INTEGER[], ARRAY[5], 0.1),
+(7, '{{0}}'::INTEGER[], ARRAY[6], 0.1),
+(8, '{{1}}'::INTEGER[], ARRAY[9], 0.1),
+(9, '{{1}}'::INTEGER[], ARRAY[9], 0.1),
+(10, '{{1}}'::INTEGER[], ARRAY[9], 0.1),
+(11, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
 
-def _tree_misclassified(schema_madlib, source_table, dependent_varname,
-                        prediction_table, pred_dep_name, id_col_name,
-                        grouping_cols, output_table, use_existing_tables=False,
-                        k=0, **kwargs):
-    old_messages = plpy.execute(
-        "SELECT setting FROM pg_settings WHERE name = 'client_min_messages'")[0]['setting']
-    plpy.execute('SET client_min_messages TO warning')
+DROP TABLE IF EXISTS tree_out, tree_out_summary;
+SELECT {schema_madlib}.tree_train(
+    'dummy_dt_src',
+    'tree_out',
+    'id',
+    'y',
+    'cat, con',
+    '',
+    'mse',
+    NULL::Text,
+    NULL::Text,
+    3,
+    2,
+    1,
+    5);
 
-    fold = ", " + str(k) + " as k"
-    if use_existing_tables and table_exists(output_table):
-        # plpy.execute("truncate " + output_table)
-        header = "INSERT INTO " + output_table + " "
+SELECT {schema_madlib}.tree_display('tree_out');
+        """
     else:
-        header = "CREATE TABLE " + output_table + " AS "
+        help_string = "No such option. Use {schema_madlib}.tree_train('usage')"
+    return help_string.format(schema_madlib=schema_madlib)
+# ------------------------------------------------------------
 
-    grouping_str = '' if not grouping_cols else "GROUP BY " + grouping_cols
-    grouping_col_str = "" if not grouping_cols else grouping_cols + ','
 
-    sql = header + """
-        SELECT
-            {grouping_col_str}
-            1.0 * sum(CASE WHEN ({prediction_table}.{pred_dep_name} =
-                           {source_table}.{dependent_varname})
-                THEN 0 ELSE 1 END) / count(*) as cv_error
-            {fold}
-        FROM
-            {prediction_table}, {source_table}
-        WHERE
-          {prediction_table}.{id_col_name} = {source_table}.{id_col_name}
-        {grouping_str}
-        """.format(output_table=output_table,
-                   prediction_table=prediction_table,
-                   pred_dep_name=pred_dep_name,
-                   source_table=source_table,
-                   dependent_varname=dependent_varname,
-                   grouping_col_str=grouping_col_str,
-                   grouping_str=grouping_str,
-                   id_col_name=id_col_name,
-                   fold=fold)
-    plpy.execute(sql)
-    plpy.execute('SET client_min_messages TO ' + old_messages)
+def tree_predict_help_message(schema_madlib, message, **kwargs):
+    """ Help message for Decision Tree predict
+    """
+    if not message:
+        help_string = """
+------------------------------------------------------------
+                        SUMMARY
+------------------------------------------------------------
+Functionality: Decision Tree Prediction
+
+Prediction for a decision tree (trained using {schema_madlib}.tree_train) can
+be performed on a new data table.
+
+For more details on the function usage:
+    SELECT {schema_madlib}.tree_predict('usage');
+For an example on using this function:
+    SELECT {schema_madlib}.tree_predict('example');
+        """
+    elif message.lower().strip() in ['usage', 'help', '?']:
+        help_string = """
+------------------------------------------------------------
+                        USAGE
+------------------------------------------------------------
+SELECT {schema_madlib}.tree_predict(
+    'tree_model',           -- Model table name (output of tree_train)
+    'new_data_table',       -- Prediction source table
+    'output_table',         -- Table name to store the predictions
+    'type'                  -- Type of prediction output
+);
+
+Note: The 'new_data_table' should have the same 'id_col_name' column as used
+in the training function. This is used to correlate the prediction data row with
+the actual prediction in the output table.
+
+------------------------------------------------------------
+                        OUTPUT
+------------------------------------------------------------
+The output table ('output_table' above) has the '<id_col_name>' column giving
+the 'id' for each prediction and the prediction columns for the response
+variable (also called the dependent variable).
+
+If prediction type = 'response', then the table has a single column with the
+prediction value of the response. The type of this column depends on the type
+of the response variable used during training.
+
+If prediction type = 'prob', then the table has multiple columns, one for each
+possible value of the response variable. The columns are labeled as
+'estimated_prob_<dep value>', where <dep value> stands for each value
+of the response.
+        """
+    elif message.lower().strip() in ['example', 'examples']:
+        help_string = """
+------------------------------------------------------------
+                        EXAMPLE
+------------------------------------------------------------
+-- Assuming the example of tree_train() has been run
+SELECT {schema_madlib}.tree_predict(
+    'tree_out',
+    'dummy_dt_src',
+    'tree_predict_out',
+    'response'
+);
+
+SELECT * FROM tree_predict_out;
+        """
+    else:
+        help_string = "No such option. Use {schema_madlib}.tree_predict('usage')"
+    return help_string.format(schema_madlib=schema_madlib)
+# ------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/madlib/blob/5b2ada5d/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
index 91e900d..c18fba9 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
@@ -1947,7 +1947,11 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._tree_rmse(
     grouping_cols          TEXT,
     output_table           VARCHAR
 ) RETURNS VOID AS $$
-    PythonFunction(recursive_partitioning, decision_tree, _tree_rmse)
+    PythonFunctionBodyOnly(recursive_partitioning, decision_tree)
+    decision_tree._tree_error(
+        schema_madlib, source_table, dependent_varname,
+        prediction_table, pred_dep_name, id_col_name, grouping_cols,
+        output_table, False)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
@@ -1962,7 +1966,11 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._tree_rmse(
     use_existing_tables    BOOLEAN,
     k                      INTEGER
 ) RETURNS VOID AS $$
-    PythonFunction(recursive_partitioning, decision_tree, _tree_rmse)
+    PythonFunctionBodyOnly(recursive_partitioning, decision_tree)
+    decision_tree._tree_error(
+        schema_madlib, source_table, dependent_varname,
+        prediction_table, pred_dep_name, id_col_name, grouping_cols,
+        output_table, False, use_existing_tables, k)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 -------------------------------------------------------------------------
@@ -1977,7 +1985,11 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._tree_misclassified(
     grouping_cols          TEXT,
     output_table           VARCHAR
 ) RETURNS VOID AS $$
-    PythonFunction(recursive_partitioning, decision_tree, _tree_misclassified)
+    PythonFunctionBodyOnly(recursive_partitioning, decision_tree)
+    decision_tree._tree_error(
+        schema_madlib, source_table, dependent_varname,
+        prediction_table, pred_dep_name, id_col_name, grouping_cols,
+        output_table, True)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
@@ -1992,6 +2004,10 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._tree_misclassified(
     use_existing_tables    BOOLEAN,
     k                      INTEGER
 ) RETURNS VOID AS $$
-    PythonFunction(recursive_partitioning, decision_tree, _tree_misclassified)
+    PythonFunctionBodyOnly(recursive_partitioning, decision_tree)
+    decision_tree._tree_error(
+        schema_madlib, source_table, dependent_varname,
+        prediction_table, pred_dep_name, id_col_name, grouping_cols,
+        output_table, True, use_existing_tables, k)
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');

http://git-wip-us.apache.org/repos/asf/madlib/blob/5b2ada5d/src/ports/postgres/modules/utilities/validate_args.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index baefa5d..2b9c6d7 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -553,7 +553,7 @@ def _tbl_dimension_rownum(schema_madlib, tbl_source, col_ind_var):
                      SELECT array_upper({col_ind_var},1) AS dimension
                      FROM {tbl_source} LIMIT 1
                  """.format(tbl_source=tbl_source,
-                        col_ind_var=col_ind_var))[0]["dimension"]
+                            col_ind_var=col_ind_var))[0]["dimension"]
     # total row number of data source table
     # The WHERE clause here ignores rows in the table that contain one or more
     # NULLs in the independent variable (x). There is no NULL check made for
@@ -569,6 +569,7 @@ def _tbl_dimension_rownum(schema_madlib, tbl_source, col_ind_var):
     return (dimension, row_num)
 # ------------------------------------------------------------------------
 
+
 def is_var_valid(tbl, var, order_by=None):
     """
     Test whether the variable(s) is valid by actually selecting it from
@@ -585,7 +586,8 @@ def is_var_valid(tbl, var, order_by=None):
             {order_by_str}
             LIMIT 0
             """.format(**locals()))
-    except Exception:
+    except Exception as e:
+        plpy.warning(str(e))
         return False
     return True
 # -------------------------------------------------------------------------