You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ri...@apache.org on 2017/11/22 00:24:50 UTC
[2/2] madlib git commit: DT: Consolidate tree_rmse and tree_misclassified
DT: Consolidate tree_rmse and tree_misclassified
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/5b2ada5d
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/5b2ada5d
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/5b2ada5d
Branch: refs/heads/master
Commit: 5b2ada5d70022bdf1f3c8659b20b57b86310f4de
Parents: 4b7d9cc
Author: Rahul Iyer <ri...@apache.org>
Authored: Mon Nov 13 10:49:29 2017 -0800
Committer: Rahul Iyer <ri...@apache.org>
Committed: Tue Nov 21 16:22:27 2017 -0800
----------------------------------------------------------------------
.../recursive_partitioning/decision_tree.py_in | 569 +++++++++----------
.../recursive_partitioning/decision_tree.sql_in | 24 +-
.../modules/utilities/validate_args.py_in | 6 +-
3 files changed, 286 insertions(+), 313 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/5b2ada5d/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
index 88be1a2..2b9d434 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
@@ -186,167 +186,6 @@ def _classify_features(feature_to_type, features):
# ------------------------------------------------------------
-def tree_train_help_message(schema_madlib, message, **kwargs):
- """ Help message for Decision Tree
- """
- if not message:
- help_string = """
-------------------------------------------------------------
- SUMMARY
-------------------------------------------------------------
-Functionality: Decision Tree
-
-Decision trees use a tree-based predictive model to
-predict the value of a target variable based on several input variables.
-
-For more details on the function usage:
- SELECT {schema_madlib}.tree_train('usage');
-For an example on using this function:
- SELECT {schema_madlib}.tree_train('example');
- """
- elif message.lower().strip() in ['usage', 'help', '?']:
- help_string = """
-------------------------------------------------------------
- USAGE
-------------------------------------------------------------
-SELECT {schema_madlib}.tree_train(
- 'training_table', -- Data table name
- 'output_table', -- Table name to store the tree model
- 'id_col_name', -- Row ID, used in tree_predict
- 'dependent_variable', -- The column to fit
- 'list_of_features', -- Comma separated column names to be
- used as the predictors, can be '*'
- to include all columns except the
- dependent_variable
- 'features_to_exclude', -- Comma separated column names to be
- excluded if list_of_features is '*'
- 'split_criterion', -- How to split a node, options are
- 'gini', 'misclassification' and
- 'entropy' for classification, and
- 'mse' for regression.
- 'grouping_cols', -- Comma separated column names used to
- group the data. A decision tree model
- will be created for each group. Default
- is NULL
- 'weights', -- A Column name containing weights for
- each observation. Default is NULL
- max_depth, -- Maximum depth of any node, default is 7
- min_split, -- Minimum number of observations that must
- exist in a node for a split to be
- attemped, default is 20
- min_bucket, -- Minimum number of observations in any
- terminal node, default is min_split/3
- n_bins, -- Number of bins to find possible node
- split threshold values for continuous
- variables, default is 20 (Must be greater than 1)
- pruning_params, -- A comma-separated text containing
- key=value pairs of parameters for pruning.
- Parameters accepted:
- 'cp' - complexity parameter with default=0.01,
- 'n_folds' - number of cross-validation folds
- with default value of 0 (= no cross-validation)
- null_handling_params, -- A comma-separated text containing
- key=value pairs of parameters for handling NULL values.
- Parameters accepted:
- 'max_surrogates' - Maximum number of surrogates to
- compute for each split
- 'null_as_category' - Boolean to indicate if
- NULL should be treated as a special category
- verbose -- Boolean, whether to print more info, default is False
-);
-
-------------------------------------------------------------
- OUTPUT
-------------------------------------------------------------
-The output table ('output_table' above) has the following columns (quoted items
-are of type TEXT):
- <grouping columns> -- Grouping columns, only present when
- 'grouping_cols' is not NULL or ''
- tree -- The decision tree model as a binary string
- cat_levels_in_text -- Distinct levels (casted to text) of all
- categorical variables combined in a single array
- cat_n_levels -- Number of distinct levels of all categorical variables
- tree_depth -- Number of levels in the tree (root has level 0)
- pruning_cp -- The cost-complexity parameter used for pruning
- the trained tree(s). This would be different
- from the input cp value if cross-validation is used.
-
-The output summary table ('output_table_summary') has the following columns:
- 'method' -- Method name: 'tree_train'
- 'source_table' -- Data table name
- 'model_table' -- Tree model table name
- 'id_col_name' -- Name of the 'id' column
- is_classification -- Boolean value indicating if tree is classification or regression
- 'dependent_varname' -- Response variable column name
- 'independent_varnames' -- Comma-separated feature column names
- 'cat_features' -- Comma-separated column names of categorical variables
- 'con_features' -- Comma-separated column names of continuous variables
- 'grouping_cols' -- Grouping column names
- num_all_groups -- Number of groups
- num_failed_groups -- Number of groups for which training failed
- total_rows_processed -- Number of rows used in the model training
- total_rows_skipped -- Number of rows skipped because NULL values
- dependent_var_levels -- For classification, the distinct levels of
- the dependent variable
- dependent_var_type -- The type of dependent variable
- input_cp -- The complexity parameter (cp) used for pruning the
- trained tree(s) (before cross-validation is run)
- independent_var_types -- The types of independent variables, comma-separated
- k -- Number of folds (NULL if not using cross validation)
- null_proxy -- String used as replacement for NULL values
- (NULL if null_as_category = False)
-
- """
- elif message.lower().strip() in ['example', 'examples']:
- help_string = """
-------------------------------------------------------------
- EXAMPLE
-------------------------------------------------------------
-DROP TABLE IF EXISTS dummy_dt_con_src CASCADE;
-CREATE TABLE dummy_dt_con_src (
- id INTEGER,
- cat INTEGER[],
- con FLOAT8[],
- y FLOAT8
-);
-
-INSERT INTO dummy_dt_src VALUES
-(1, '{{0}}'::INTEGER[], ARRAY[0], 0.5),
-(2, '{{0}}'::INTEGER[], ARRAY[1], 0.5),
-(3, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
-(4, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
-(5, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
-(6, '{{0}}'::INTEGER[], ARRAY[5], 0.1),
-(7, '{{0}}'::INTEGER[], ARRAY[6], 0.1),
-(8, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
-(9, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
-(10, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
-(11, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
-
-DROP TABLE IF EXISTS tree_out, tree_out_summary;
-SELECT madlib.tree_train(
- 'dummy_dt_src',
- 'tree_out',
- 'id',
- 'y',
- 'cat, con',
- '',
- 'mse',
- NULL::Text,
- NULL::Text,
- 3,
- 2,
- 1,
- 5);
-
-SELECT madlib.tree_display('tree_out');
- """
- else:
- help_string = "No such option. Use {schema_madlib}.tree_train('usage')"
- return help_string.format(schema_madlib=schema_madlib)
-# ------------------------------------------------------------
-
-
def _extract_pruning_params(pruning_params_str):
"""
Args:
@@ -1979,77 +1818,6 @@ def tree_display(schema_madlib, model_table, dot_format=True, verbose=False,
# -------------------------------------------------------------------------
-def tree_predict_help_message(schema_madlib, message, **kwargs):
- """ Help message for Decision Tree predict
- """
- if not message:
- help_string = """
-------------------------------------------------------------
- SUMMARY
-------------------------------------------------------------
-Functionality: Decision Tree Prediction
-
-Prediction for a decision tree (trained using {schema_madlib}.tree_train) can
-be performed on a new data table.
-
-For more details on the function usage:
- SELECT {schema_madlib}.tree_predict('usage');
-For an example on using this function:
- SELECT {schema_madlib}.tree_predict('example');
- """
- elif message.lower().strip() in ['usage', 'help', '?']:
- help_string = """
-------------------------------------------------------------
- USAGE
-------------------------------------------------------------
-SELECT {schema_madlib}.tree_predict(
- 'tree_model', -- Model table name (output of tree_train)
- 'new_data_table', -- Prediction source table
- 'output_table', -- Table name to store the predictions
- 'type' -- Type of prediction output
-);
-
-Note: The 'new_data_table' should have the same 'id_col_name' column as used
-in the training function. This is used to corelate the prediction data row with
-the actual prediction in the output table.
-
-------------------------------------------------------------
- OUTPUT
-------------------------------------------------------------
-The output table ('output_table' above) has the '<id_col_name>' column giving
-the 'id' for each prediction and the prediction columns for the response
-variable (also called as dependent variable).
-
-If prediction type = 'response', then the table has a single column with the
-prediction value of the response. The type of this column depends on the type
-of the response variable used during training.
-
-If prediction type = 'prob', then the table has multiple columns, one for each
-possible value of the response variable. The columns are labeled as
-'estimated_prob_<dep value>', where <dep value> represents for each value
-of the response.
- """
- elif message.lower().strip() in ['example', 'examples']:
- help_string = """
-------------------------------------------------------------
- EXAMPLE
-------------------------------------------------------------
--- Assuming the example of tree_train() has been run
-SELECT {schema_madlib}.tree_predict(
- 'tree_out',
- 'dummy_dt_src',
- 'tree_predict_out',
- 'response'
-);
-
-SELECT * FROM tree_predict_out;
- """
- else:
- help_string = "No such option. Use {schema_madlib}.tree_predict('usage')"
- return help_string.format(schema_madlib=schema_madlib)
-# ------------------------------------------------------------
-
-
def _prune_and_cplist(schema_madlib, tree, cp, compute_cp_list=False):
""" Prune tree with given cost-complexity parameters
and return a list of cp values at which tree can be pruned
@@ -2425,87 +2193,274 @@ def _tree_train_grps_using_bins(
# ------------------------------------------------------------
-def _tree_rmse(schema_madlib, source_table, dependent_varname, prediction_table,
- pred_dep_name, id_col_name, grouping_cols, output_table,
- use_existing_tables=False, k=0, **kwargs):
- old_messages = plpy.execute("SELECT setting FROM pg_settings "
- "WHERE name = 'client_min_messages'")[0]['setting']
- plpy.execute('SET client_min_messages TO warning')
+def _tree_error(schema_madlib, source_table, dependent_varname,
+ prediction_table, pred_dep_name, id_col_name, grouping_cols,
+ output_table, is_classification,
+ use_existing_tables=False, k=0, **kwargs):
+ with MinWarning("warning"):
+ if use_existing_tables and table_exists(output_table):
+ # plpy.execute("truncate " + output_table)
+ header = "INSERT INTO " + output_table + " "
+ else:
+ header = "CREATE TABLE " + output_table + " AS "
+ if is_classification:
+ error_func = """
+ 1.0 * sum(CASE WHEN ({prediction_table}.{pred_dep_name} =
+ {source_table}.{dependent_varname})
+ THEN 0
+ ELSE 1
+ END) / count(*)
+ """.format(**locals())
+ else:
+ error_func = """
+ sqrt(avg(({prediction_table}.{pred_dep_name} -
+ {source_table}.{dependent_varname})^2
+ )
+ )
+ """.format(**locals())
+ grouping_str = '' if not grouping_cols else "GROUP BY " + grouping_cols
+ grouping_col_str = '' if not grouping_cols else grouping_cols + ','
+
+ sql = header + """
+ SELECT
+ {grouping_col_str}
+ {error_func} as cv_error,
+ {k} as k
+ FROM {prediction_table}, {source_table}
+ WHERE {prediction_table}.{id_col_name} = {source_table}.{id_col_name}
+ {grouping_str}
+ """.format(**locals())
+ plpy.execute(sql)
+# ------------------------------------------------------------
+
+
+def tree_train_help_message(schema_madlib, message, **kwargs):
+ """ Help message for Decision Tree
+ """
+ if not message:
+ help_string = """
+------------------------------------------------------------
+ SUMMARY
+------------------------------------------------------------
+Functionality: Decision Tree
- fold = ", " + str(k) + " as k"
- if use_existing_tables and table_exists(output_table):
- # plpy.execute("truncate " + output_table)
- header = "INSERT INTO " + output_table + " "
- else:
- header = "CREATE TABLE " + output_table + " AS "
+Decision trees use a tree-based predictive model to
+predict the value of a target variable based on several input variables.
- grouping_str = '' if not grouping_cols else "GROUP BY " + grouping_cols
- grouping_col_str = "" if not grouping_cols else grouping_cols + ','
+For more details on the function usage:
+ SELECT {schema_madlib}.tree_train('usage');
+For an example on using this function:
+ SELECT {schema_madlib}.tree_train('example');
+ """
+ elif message.lower().strip() in ['usage', 'help', '?']:
+ help_string = """
+------------------------------------------------------------
+ USAGE
+------------------------------------------------------------
+SELECT {schema_madlib}.tree_train(
+ 'training_table', -- Data table name
+ 'output_table', -- Table name to store the tree model
+ 'id_col_name', -- Row ID, used in tree_predict
+ 'dependent_variable', -- The column to fit
+ 'list_of_features', -- Comma separated column names to be
+ used as the predictors, can be '*'
+ to include all columns except the
+ dependent_variable
+ 'features_to_exclude', -- Comma separated column names to be
+ excluded if list_of_features is '*'
+ 'split_criterion', -- How to split a node, options are
+ 'gini', 'misclassification' and
+ 'entropy' for classification, and
+ 'mse' for regression.
+ 'grouping_cols', -- Comma separated column names used to
+ group the data. A decision tree model
+ will be created for each group. Default
+ is NULL
+ 'weights', -- A Column name containing weights for
+ each observation. Default is NULL
+ max_depth, -- Maximum depth of any node, default is 7
+ min_split, -- Minimum number of observations that must
+ exist in a node for a split to be
+ attemped, default is 20
+ min_bucket, -- Minimum number of observations in any
+ terminal node, default is min_split/3
+ n_bins, -- Number of bins to find possible node
+ split threshold values for continuous
+ variables, default is 20 (Must be greater than 1)
+ pruning_params, -- A comma-separated text containing
+ key=value pairs of parameters for pruning.
+ Parameters accepted:
+ 'cp' - complexity parameter with default=0.01,
+ 'n_folds' - number of cross-validation folds
+ with default value of 0 (= no cross-validation)
+ null_handling_params, -- A comma-separated text containing
+ key=value pairs of parameters for handling NULL values.
+ Parameters accepted:
+ 'max_surrogates' - Maximum number of surrogates to
+ compute for each split
+ 'null_as_category' - Boolean to indicate if
+ NULL should be treated as a special category
+ verbose -- Boolean, whether to print more info, default is False
+);
- sql = header + """
- SELECT
- {grouping_col_str}
- sqrt(avg(
- ({prediction_table}.{pred_dep_name} -
- {source_table}.{dependent_varname})^2
- )
- ) as cv_error
- {fold}
- FROM {prediction_table}, {source_table}
- WHERE {prediction_table}.{id_col_name} = {source_table}.{id_col_name}
- {grouping_str}
- """.format(output_table=output_table,
- prediction_table=prediction_table,
- pred_dep_name=pred_dep_name,
- source_table=source_table,
- dependent_varname=dependent_varname,
- grouping_col_str=grouping_col_str,
- grouping_str=grouping_str,
- id_col_name=id_col_name,
- fold=fold)
- plpy.execute(sql)
- plpy.execute('SET client_min_messages TO ' + old_messages)
-# ------------------------------------------------------------
+------------------------------------------------------------
+ OUTPUT
+------------------------------------------------------------
+The output table ('output_table' above) has the following columns (quoted items
+are of type TEXT):
+ <grouping columns> -- Grouping columns, only present when
+ 'grouping_cols' is not NULL or ''
+ tree -- The decision tree model as a binary string
+ cat_levels_in_text -- Distinct levels (casted to text) of all
+ categorical variables combined in a single array
+ cat_n_levels -- Number of distinct levels of all categorical variables
+ tree_depth -- Number of levels in the tree (root has level 0)
+ pruning_cp -- The cost-complexity parameter used for pruning
+ the trained tree(s). This would be different
+ from the input cp value if cross-validation is used.
+
+The output summary table ('output_table_summary') has the following columns:
+ 'method' -- Method name: 'tree_train'
+ 'source_table' -- Data table name
+ 'model_table' -- Tree model table name
+ 'id_col_name' -- Name of the 'id' column
+ is_classification -- Boolean value indicating if tree is classification or regression
+ 'dependent_varname' -- Response variable column name
+ 'independent_varnames' -- Comma-separated feature column names
+ 'cat_features' -- Comma-separated column names of categorical variables
+ 'con_features' -- Comma-separated column names of continuous variables
+ 'grouping_cols' -- Grouping column names
+ num_all_groups -- Number of groups
+ num_failed_groups -- Number of groups for which training failed
+ total_rows_processed -- Number of rows used in the model training
+ total_rows_skipped -- Number of rows skipped because NULL values
+ dependent_var_levels -- For classification, the distinct levels of
+ the dependent variable
+ dependent_var_type -- The type of dependent variable
+ input_cp -- The complexity parameter (cp) used for pruning the
+ trained tree(s) (before cross-validation is run)
+ independent_var_types -- The types of independent variables, comma-separated
+ k -- Number of folds (NULL if not using cross validation)
+ null_proxy -- String used as replacement for NULL values
+ (NULL if null_as_category = False)
+
+ """
+ elif message.lower().strip() in ['example', 'examples']:
+ help_string = """
+------------------------------------------------------------
+ EXAMPLE
+------------------------------------------------------------
+DROP TABLE IF EXISTS dummy_dt_con_src CASCADE;
+CREATE TABLE dummy_dt_con_src (
+ id INTEGER,
+ cat INTEGER[],
+ con FLOAT8[],
+ y FLOAT8
+);
+INSERT INTO dummy_dt_src VALUES
+(1, '{{0}}'::INTEGER[], ARRAY[0], 0.5),
+(2, '{{0}}'::INTEGER[], ARRAY[1], 0.5),
+(3, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
+(4, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
+(5, '{{0}}'::INTEGER[], ARRAY[4], 0.5),
+(6, '{{0}}'::INTEGER[], ARRAY[5], 0.1),
+(7, '{{0}}'::INTEGER[], ARRAY[6], 0.1),
+(8, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
+(9, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
+(10, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
+(11, '{{1}}'::INTEGER[], ARRAY[9], 0.1);
-def _tree_misclassified(schema_madlib, source_table, dependent_varname,
- prediction_table, pred_dep_name, id_col_name,
- grouping_cols, output_table, use_existing_tables=False,
- k=0, **kwargs):
- old_messages = plpy.execute(
- "SELECT setting FROM pg_settings WHERE name = 'client_min_messages'")[0]['setting']
- plpy.execute('SET client_min_messages TO warning')
+DROP TABLE IF EXISTS tree_out, tree_out_summary;
+SELECT madlib.tree_train(
+ 'dummy_dt_src',
+ 'tree_out',
+ 'id',
+ 'y',
+ 'cat, con',
+ '',
+ 'mse',
+ NULL::Text,
+ NULL::Text,
+ 3,
+ 2,
+ 1,
+ 5);
- fold = ", " + str(k) + " as k"
- if use_existing_tables and table_exists(output_table):
- # plpy.execute("truncate " + output_table)
- header = "INSERT INTO " + output_table + " "
+SELECT madlib.tree_display('tree_out');
+ """
else:
- header = "CREATE TABLE " + output_table + " AS "
+ help_string = "No such option. Use {schema_madlib}.tree_train('usage')"
+ return help_string.format(schema_madlib=schema_madlib)
+# ------------------------------------------------------------
- grouping_str = '' if not grouping_cols else "GROUP BY " + grouping_cols
- grouping_col_str = "" if not grouping_cols else grouping_cols + ','
- sql = header + """
- SELECT
- {grouping_col_str}
- 1.0 * sum(CASE WHEN ({prediction_table}.{pred_dep_name} =
- {source_table}.{dependent_varname})
- THEN 0 ELSE 1 END) / count(*) as cv_error
- {fold}
- FROM
- {prediction_table}, {source_table}
- WHERE
- {prediction_table}.{id_col_name} = {source_table}.{id_col_name}
- {grouping_str}
- """.format(output_table=output_table,
- prediction_table=prediction_table,
- pred_dep_name=pred_dep_name,
- source_table=source_table,
- dependent_varname=dependent_varname,
- grouping_col_str=grouping_col_str,
- grouping_str=grouping_str,
- id_col_name=id_col_name,
- fold=fold)
- plpy.execute(sql)
- plpy.execute('SET client_min_messages TO ' + old_messages)
+def tree_predict_help_message(schema_madlib, message, **kwargs):
+ """ Help message for Decision Tree predict
+ """
+ if not message:
+ help_string = """
+------------------------------------------------------------
+ SUMMARY
+------------------------------------------------------------
+Functionality: Decision Tree Prediction
+
+Prediction for a decision tree (trained using {schema_madlib}.tree_train) can
+be performed on a new data table.
+
+For more details on the function usage:
+ SELECT {schema_madlib}.tree_predict('usage');
+For an example on using this function:
+ SELECT {schema_madlib}.tree_predict('example');
+ """
+ elif message.lower().strip() in ['usage', 'help', '?']:
+ help_string = """
+------------------------------------------------------------
+ USAGE
+------------------------------------------------------------
+SELECT {schema_madlib}.tree_predict(
+ 'tree_model', -- Model table name (output of tree_train)
+ 'new_data_table', -- Prediction source table
+ 'output_table', -- Table name to store the predictions
+ 'type' -- Type of prediction output
+);
+
+Note: The 'new_data_table' should have the same 'id_col_name' column as used
+in the training function. This is used to corelate the prediction data row with
+the actual prediction in the output table.
+
+------------------------------------------------------------
+ OUTPUT
+------------------------------------------------------------
+The output table ('output_table' above) has the '<id_col_name>' column giving
+the 'id' for each prediction and the prediction columns for the response
+variable (also called as dependent variable).
+
+If prediction type = 'response', then the table has a single column with the
+prediction value of the response. The type of this column depends on the type
+of the response variable used during training.
+
+If prediction type = 'prob', then the table has multiple columns, one for each
+possible value of the response variable. The columns are labeled as
+'estimated_prob_<dep value>', where <dep value> represents for each value
+of the response.
+ """
+ elif message.lower().strip() in ['example', 'examples']:
+ help_string = """
+------------------------------------------------------------
+ EXAMPLE
+------------------------------------------------------------
+-- Assuming the example of tree_train() has been run
+SELECT {schema_madlib}.tree_predict(
+ 'tree_out',
+ 'dummy_dt_src',
+ 'tree_predict_out',
+ 'response'
+);
+
+SELECT * FROM tree_predict_out;
+ """
+ else:
+ help_string = "No such option. Use {schema_madlib}.tree_predict('usage')"
+ return help_string.format(schema_madlib=schema_madlib)
+# ------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/5b2ada5d/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
index 91e900d..c18fba9 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
@@ -1947,7 +1947,11 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._tree_rmse(
grouping_cols TEXT,
output_table VARCHAR
) RETURNS VOID AS $$
- PythonFunction(recursive_partitioning, decision_tree, _tree_rmse)
+ PythonFunctionBodyOnly(recursive_partitioning, decision_tree)
+ decision_tree._tree_error(
+ schema_madlib, source_table, dependent_varname,
+ prediction_table, pred_dep_name, id_col_name, grouping_cols,
+ output_table, False)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -1962,7 +1966,11 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._tree_rmse(
use_existing_tables BOOLEAN,
k INTEGER
) RETURNS VOID AS $$
- PythonFunction(recursive_partitioning, decision_tree, _tree_rmse)
+ PythonFunctionBodyOnly(recursive_partitioning, decision_tree)
+ decision_tree._tree_error(
+ schema_madlib, source_table, dependent_varname,
+ prediction_table, pred_dep_name, id_col_name, grouping_cols,
+ output_table, False, use_existing_tables, k)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
-------------------------------------------------------------------------
@@ -1977,7 +1985,11 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._tree_misclassified(
grouping_cols TEXT,
output_table VARCHAR
) RETURNS VOID AS $$
- PythonFunction(recursive_partitioning, decision_tree, _tree_misclassified)
+ PythonFunctionBodyOnly(recursive_partitioning, decision_tree)
+ decision_tree._tree_error(
+ schema_madlib, source_table, dependent_varname,
+ prediction_table, pred_dep_name, id_col_name, grouping_cols,
+ output_table, True)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -1992,6 +2004,10 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._tree_misclassified(
use_existing_tables BOOLEAN,
k INTEGER
) RETURNS VOID AS $$
- PythonFunction(recursive_partitioning, decision_tree, _tree_misclassified)
+ PythonFunctionBodyOnly(recursive_partitioning, decision_tree)
+ decision_tree._tree_error(
+ schema_madlib, source_table, dependent_varname,
+ prediction_table, pred_dep_name, id_col_name, grouping_cols,
+ output_table, True, use_existing_tables, k)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
http://git-wip-us.apache.org/repos/asf/madlib/blob/5b2ada5d/src/ports/postgres/modules/utilities/validate_args.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index baefa5d..2b9c6d7 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -553,7 +553,7 @@ def _tbl_dimension_rownum(schema_madlib, tbl_source, col_ind_var):
SELECT array_upper({col_ind_var},1) AS dimension
FROM {tbl_source} LIMIT 1
""".format(tbl_source=tbl_source,
- col_ind_var=col_ind_var))[0]["dimension"]
+ col_ind_var=col_ind_var))[0]["dimension"]
# total row number of data source table
# The WHERE clause here ignores rows in the table that contain one or more
# NULLs in the independent variable (x). There is no NULL check made for
@@ -569,6 +569,7 @@ def _tbl_dimension_rownum(schema_madlib, tbl_source, col_ind_var):
return (dimension, row_num)
# ------------------------------------------------------------------------
+
def is_var_valid(tbl, var, order_by=None):
"""
Test whether the variable(s) is valid by actually selecting it from
@@ -585,7 +586,8 @@ def is_var_valid(tbl, var, order_by=None):
{order_by_str}
LIMIT 0
""".format(**locals()))
- except Exception:
+ except Exception as e:
+ plpy.warning(str(e))
return False
return True
# -------------------------------------------------------------------------