You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ri...@apache.org on 2016/01/15 01:53:39 UTC

incubator-madlib git commit: Correlation: Return columns sorted in ordinal position

Repository: incubator-madlib
Updated Branches:
  refs/heads/master deb175ab3 -> eea1f1f76


Correlation: Return columns sorted in ordinal position

JIRA: MADLIB-941

A couple of minor issues are fixed here:
1. The get_cols_and_types utility function returned columns
in an arbitrary order because it built a dictionary from them at the end.
This has been fixed by returning a list of (name, type) tuples ordered by
column position. An OrderedDict would be the best choice here, but some
platforms still run Python 2.6, which lacks it.
2. Multiple modules that depend on this function were updated, either
to build the dict in the calling function or to use the list directly
instead of a dict.


Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/eea1f1f7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/eea1f1f7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/eea1f1f7

Branch: refs/heads/master
Commit: eea1f1f7699b713f4ac7e89cbae01d7030f21551
Parents: deb175a
Author: Rahul Iyer <ri...@pivotal.io>
Authored: Thu Jan 14 16:40:40 2016 -0800
Committer: Rahul Iyer <ri...@pivotal.io>
Committed: Thu Jan 14 16:44:56 2016 -0800

----------------------------------------------------------------------
 methods/svec_util/src/pg_gp/generate_svec.py_in |   4 +-
 .../modules/elastic_net/elastic_net.py_in       |   2 +-
 .../recursive_partitioning/decision_tree.py_in  |   9 +-
 .../recursive_partitioning/random_forest.py_in  |   9 +-
 .../postgres/modules/stats/correlation.py_in    | 128 +++++++++----------
 .../modules/utilities/validate_args.py_in       |  63 ++++++---
 6 files changed, 113 insertions(+), 102 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/eea1f1f7/methods/svec_util/src/pg_gp/generate_svec.py_in
----------------------------------------------------------------------
diff --git a/methods/svec_util/src/pg_gp/generate_svec.py_in b/methods/svec_util/src/pg_gp/generate_svec.py_in
index bc691ab..6a9e689 100644
--- a/methods/svec_util/src/pg_gp/generate_svec.py_in
+++ b/methods/svec_util/src/pg_gp/generate_svec.py_in
@@ -167,12 +167,12 @@ def _validate_args(schema_madlib, output_tbl, dictionary_tbl, dict_id_col,
     # table and term info column from documents table to be of any of the
     # respective required types.
     #
-    dict_col_type_dict = get_cols_and_types(dictionary_tbl)
+    dict_col_type_dict = dict(get_cols_and_types(dictionary_tbl))
     _assert(verify_type(dict_col_type_dict, dict_id_col, True),
         "Svec error: Unexpected type for column '%s' in dictionary table."
         " Should be int or bigint type" % dict_id_col)
 
-    doc_col_type_dict = get_cols_and_types(documents_tbl)
+    doc_col_type_dict = dict(get_cols_and_types(documents_tbl))
     _assert(verify_type(doc_col_type_dict, doc_term_info_col, False),
         "Svec error: Unexpected type for column '%s' in documents table."
         " Should be int, bigint, double precision or Array type" % doc_term_info_col)

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/eea1f1f7/src/ports/postgres/modules/elastic_net/elastic_net.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/elastic_net/elastic_net.py_in b/src/ports/postgres/modules/elastic_net/elastic_net.py_in
index ad16d54..21c581a 100644
--- a/src/ports/postgres/modules/elastic_net/elastic_net.py_in
+++ b/src/ports/postgres/modules/elastic_net/elastic_net.py_in
@@ -401,7 +401,7 @@ def analyze_input_str(schema_madlib, tbl_source,
 
     outstr_array = []
     if col_ind_var == "*":
-        col_types_dict = get_cols_and_types(tbl_source)
+        col_types_dict = dict(get_cols_and_types(tbl_source))
         cols = col_types_dict.keys()
 
         s = _string_to_array(excluded) if excluded is not None else []

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/eea1f1f7/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
index df97196..6fc8689 100644
--- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
@@ -471,7 +471,7 @@ def get_grouping_array_str(table_name, grouping_cols, qualifier=None):
     else:
         qualifier_str = ''
 
-    all_cols_types = get_cols_and_types(table_name)
+    all_cols_types = dict(get_cols_and_types(table_name))
     grouping_cols_list = [col.strip() for col in grouping_cols.split(',')]
     grouping_cols_and_types = [(col, _get_col_value(all_cols_types, col))
                                for col in grouping_cols_list]
@@ -508,7 +508,7 @@ def _build_tree(schema_madlib, is_classification, split_criterion,
     with MinWarning(msg_level):
         plpy.notice("Building tree for cross validation")
         tree_states, bins, dep_list, n_rows = _get_tree_states(**locals())
-        all_cols_types = get_cols_and_types(training_table_name)
+        all_cols_types = dict(get_cols_and_types(training_table_name))
         n_all_rows = plpy.execute("select count(*) from " + training_table_name
                                   )[0]['count']
         cp = grp_key_to_cp.values()[0]
@@ -589,7 +589,7 @@ def tree_train(schema_madlib, training_table_name, output_table_name,
                 "Decision tree error: No feature is selected for the model.")
 
         # 2)
-        all_cols_types = get_cols_and_types(training_table_name)
+        all_cols_types = dict(get_cols_and_types(training_table_name))
         cat_features, con_features, boolean_cats = _classify_features(
             all_cols_types, features)
         # get all rows
@@ -1668,8 +1668,7 @@ def tree_predict(schema_madlib, model, source, output, pred_type='response',
     dep_type = summary_elements['dependent_var_type']
 
     # find which columns are of type boolean
-    boolean_cats = set([key for key, value in
-                        (get_cols_and_types(source)).iteritems()
+    boolean_cats = set([key for key, value in get_cols_and_types(source)
                         if value == 'boolean'])
 
     cat_features_str, con_features_str = get_feature_str(

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/eea1f1f7/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
index bd9ce70..a47dd9b 100644
--- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
@@ -20,7 +20,7 @@ from utilities.validate_args import output_tbl_valid
 from utilities.validate_args import cols_in_tbl_valid
 from utilities.utilities import _assert
 from utilities.utilities import unique_string
-from utilities.utilities import add_postfix 
+from utilities.utilities import add_postfix
 from utilities.utilities import split_quoted_delimited_str
 from utilities.utilities import extract_keyvalue_params
 
@@ -317,7 +317,7 @@ def forest_train(
                         "Random forest error: Number of features to be selected is more "
                         "than the actual number of features.")
 
-                all_cols_types = get_cols_and_types(training_table_name)
+                all_cols_types = dict(get_cols_and_types(training_table_name))
                 cat_features, con_features, boolean_cats = _classify_features(
                     all_cols_types, features)
 
@@ -685,8 +685,7 @@ def forest_predict(schema_madlib, model, source, output, pred_type='response',
             "Random forest error: pred_type cannot be 'prob' for regression model.")
 
     # find which columns are of type boolean
-    boolean_cats = set([key for key, value in
-                        (get_cols_and_types(source)).iteritems()
+    boolean_cats = set([key for key, value in get_cols_and_types(source)
                         if value == 'boolean'])
 
     cat_features_str, con_features_str = get_feature_str(
@@ -976,7 +975,7 @@ def _calculate_oob_prediction(
                 NULL::float8[] AS con_index_distributions
             FROM {cat_features_info_table}
         """.format(**locals())
-        
+
     plpy.notice("sql_create_oob_var_dist_view : " + str(sql_create_oob_var_dist_view))
     plpy.execute(sql_create_oob_var_dist_view)
 

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/eea1f1f7/src/ports/postgres/modules/stats/correlation.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/stats/correlation.py_in b/src/ports/postgres/modules/stats/correlation.py_in
index 6bc1b5a..83604f4 100644
--- a/src/ports/postgres/modules/stats/correlation.py_in
+++ b/src/ports/postgres/modules/stats/correlation.py_in
@@ -5,11 +5,13 @@
 
 @namespace correlation
 """
-import plpy
 from time import time
+
+import plpy
+from utilities.control import MinWarning
 from utilities.utilities import unique_string, add_postfix
 from utilities.validate_args import get_cols_and_types
-from utilities.control import MinWarning
+from utilities.validate_args import input_tbl_valid, output_tbl_valid, cols_in_tbl_valid
 
 
 def correlation(schema_madlib, source_table, output_table,
@@ -30,30 +32,40 @@ def correlation(schema_madlib, source_table, output_table,
     Returns:
         Tuple (output table name, number of columns, time for computation)
     """
-    _validate_corr_arg(source_table, output_table)
-    _numeric_column_names, _nonnumeric_column_names = _get_numeric_columns(source_table)
-    _target_cols = _analyze_target_cols(target_cols)
-    _nonexisting_target_cols = None
-    if _target_cols:
-        # prune all non-numeric column types from target columns
-        # ensure all column names are unique since we'll use them as output
-        # column names
-        _existing_target_cols = list(set(column for column in _target_cols
-                                         if column in _numeric_column_names))
-        _nonexisting_target_cols = list(set(_target_cols) -
-                                        set(_existing_target_cols) -
-                                        set(_nonnumeric_column_names))
-        _nonnumeric_target_cols = list(set(_target_cols) &
-                                       set(_nonnumeric_column_names))
-    else:
-        # if target_cols not provided then all numeric columns are to be targeted
-        _existing_target_cols = _numeric_column_names[:]
-        _nonnumeric_target_cols = _nonnumeric_column_names[:]
+    with MinWarning("info" if verbose else "error"):
+        _validate_corr_arg(source_table, output_table)
+        _numeric_column_names, _nonnumeric_column_names = _get_numeric_columns(source_table)
+        _target_cols = _analyze_target_cols(source_table, target_cols)
+        if _target_cols:
+            # prune all non-numeric column types from target columns
+            _existing_target_cols = []
+            # we create a copy using a set to efficiently check membership (see below)
+            _existing_target_cols_check = set()
+            _nonexisting_target_cols = []
+            _nonnumeric_target_cols = []
+            _numeric_column_names = set(_numeric_column_names)
+            _nonnumeric_column_names = set(_nonnumeric_column_names)
+            for col in _target_cols:
+                if col in _numeric_column_names:
+                    # efficiently check membership using the set
+                    # ensure all column names are unique since they're output column names
+                    if col not in _existing_target_cols_check:
+                        _existing_target_cols.append(col)
+                        _existing_target_cols_check.add(col)
+                elif col in _nonnumeric_column_names:
+                    _nonnumeric_target_cols.append(col)
+                else:
+                    _nonexisting_target_cols.append(col)
+        else:
+            # if target_cols not provided then all numeric columns are to be targeted
+            _existing_target_cols = _numeric_column_names[:]
+            _nonnumeric_target_cols = _nonnumeric_column_names[:]
+            _nonexisting_target_cols = []
 
-    if len(_existing_target_cols) == 0:
-        plpy.error("Correlation error: No numeric column found in the target list.")
-    if len(_existing_target_cols) == 1:
-        plpy.error("Correlation error: Only one numeric column found in the target list.")
+        if not _existing_target_cols:
+            plpy.error("Correlation error: No numeric column found in the target list.")
+        if len(_existing_target_cols) == 1:
+            plpy.error("Correlation error: Only one numeric column found in the target list.")
 
     # ---- Output message ----
     output_text_mesasge = "Summary for 'correlation' function"
@@ -66,16 +78,14 @@ def correlation(schema_madlib, source_table, output_table,
 
     output_text_mesasge += ("\n Producing correlation for columns: {0}".
                             format(str(_existing_target_cols)))
-    plpy.info(output_text_mesasge)
+    plpy.notice(output_text_mesasge)
     # ---- Output message ----
 
     return _populate_output_table(schema_madlib, source_table, output_table,
                                   _existing_target_cols, get_cov, verbose)
+# ------------------------------------------------------------------------------
 
 
-# -----------------------------------------------------------------------
-# Argument validation function
-# -----------------------------------------------------------------------
 def _validate_corr_arg(source_table, output_table):
     """
     Validates all arguments and raises an error if there is an invalid argument
@@ -88,28 +98,12 @@ def _validate_corr_arg(source_table, output_table):
     Returns:
         True if all arguments are valid
     """
-    if not source_table or source_table.strip() == '':
-        plpy.error("""
-            Correlation error: Invalid source table name""")
-    try:
-        plpy.execute("SELECT '{0}'::regclass::oid\
-                                    ".format(source_table))[0]['oid']
-    except:
-        plpy.error("Correlation error:  Relation '{0}' does not exist\
-                        ".format(source_table))
-    rowcount = plpy.execute("""
-        SELECT count(*) FROM {0}""".format(source_table))[0]['count']
-    if rowcount == 0:
-        plpy.error("Relation '{0}' is empty".format(source_table))
-
-    if not output_table or output_table.strip() == '':
-        plpy.error("Correlation error: Invalid output table name")
-    return True
-
-
-# -----------------------------------------------------------------------
-# Get all numeric column names in source table
-# -----------------------------------------------------------------------
+    input_tbl_valid(source_table, "Correlation")
+    output_tbl_valid(output_table, "Correlation")
+    output_tbl_valid(add_postfix(output_table, "_summary"), "Correlation")
+# ------------------------------------------------------------------------------
+
+
 def _get_numeric_columns(source_table):
     """
     Returns all column names for numeric type columns in a relation
@@ -124,33 +118,30 @@ def _get_numeric_columns(source_table):
     # retrieve the numeric columns
     numeric_types = ('smallint', 'integer', 'bigint',
                      'real', 'numeric', 'double precision')
-    all_col_types = get_cols_and_types(source_table)
-    all_col_names = set(all_col_types.keys())
-
-    numeric_col_names = [c for c in all_col_names
-                         if all_col_types[c] in numeric_types]
-    nonnum_col_names = list(all_col_names - set(numeric_col_names))
-
+    numeric_col_names = []
+    nonnum_col_names = []
+    for col_name, col_type in get_cols_and_types(source_table):
+        if col_type in numeric_types:
+            numeric_col_names.append(col_name)
+        else:
+            nonnum_col_names.append(col_name)
     return (numeric_col_names, nonnum_col_names)
+# ------------------------------------------------------------------------------
 
 
-# -----------------------------------------------------------------------
-# Input parameter checks and edits
-# -----------------------------------------------------------------------
-def _analyze_target_cols(target_cols):
+def _analyze_target_cols(source_table, target_cols):
     """
     Analyzes target_cols string input and converts it to a list
     """
     if not target_cols or target_cols.strip() in ('', '*'):
         target_cols = None
     else:
-        target_cols = target_cols.replace(' ', '').split(',')
+        target_cols = [i.strip() for i in target_cols.split(',')]  # strip ends only; inner spaces kept
+        cols_in_tbl_valid(source_table, target_cols, "Correlation")  # errors on empty/missing names
     return target_cols
+# ------------------------------------------------------------------------------
 
 
-# -----------------------------------------------------------------------
-# Create and populate output table
-# -----------------------------------------------------------------------
 def _populate_output_table(schema_madlib, source_table, output_table,
                            col_names, get_cov=True, verbose=False):
     """
@@ -265,11 +256,9 @@ def _populate_output_table(schema_madlib, source_table, output_table,
         plpy.execute("DROP TABLE {temp_table}".format(**locals()))
         end = time()
         return (output_table, len(col_names), end - start)
+# ------------------------------------------------------------------------------
 
 
-# -----------------------------------------------------------------------
-# Help messages
-# -----------------------------------------------------------------------
 def correlation_help_message(schema_madlib, message, cov=False, **kwargs):
     """
     Given a help string, provide usage information
@@ -375,3 +364,4 @@ OR
     SELECT {schema_madlib}.covariance('example');
 
             """.format(schema_madlib=schema_madlib)
+# ------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/eea1f1f7/src/ports/postgres/modules/utilities/validate_args.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index ed5343f..7093122 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -107,7 +107,7 @@ def table_exists(tbl, only_first_schema=False):
     schema_str, table = _get_table_schema_names(tbl, only_first_schema)
     if schema_str and table:
         schema_expr = "LIKE 'pg_temp%'" if schema_str == "('pg_temp')" \
-                else 'IN {0}'.format(schema_str)
+            else 'IN {0}'.format(schema_str)
         does_table_exist = plpy.execute(
             """
             SELECT EXISTS(
@@ -152,10 +152,10 @@ def rename_table(schema_madlib, orig_name, new_name):
     if len(orig_names_split) > 1:
         orig_table_schema = orig_names_split[0]
     else:
-        ## we need to get the schema name of the original table if we are
-        ## to change the schema of the new table. This is to ensure that we
-        ## change the schema of the correct table in case there are multiple
-        ## tables with the same new name.
+        # we need to get the schema name of the original table if we are
+        # to change the schema of the new table. This is to ensure that we
+        # change the schema of the correct table in case there are multiple
+        # tables with the same new name.
         orig_table_schema = get_first_schema(orig_name)
 
     if orig_table_schema is None:
@@ -167,7 +167,7 @@ def rename_table(schema_madlib, orig_name, new_name):
 
     if new_table_schema:
         if new_table_schema != orig_table_schema:
-            ## set schema only if a change in schema is required
+            # set schema only if a change in schema is required
             before_schema_string = "{0}.{1}".format(orig_table_schema,
                                                     new_table_name)
             plpy.execute("""ALTER TABLE {new_table}
@@ -201,16 +201,16 @@ def get_first_schema(table_name):
     elif len(names) == 2:
         return _unquote_name(names[0])
 
-    ## create a list of schema names in search path
-    ## _string_to_array is used for GPDB versions less than 4.2 where an array
-    ## is returned to Python as a string
+    # create a list of schema names in search path
+    # _string_to_array is used for GPDB versions less than 4.2 where an array
+    # is returned to Python as a string
     current_schemas = _string_to_array(plpy.execute(
         "SELECT current_schemas(True) AS cs")[0]["cs"])
 
     if not current_schemas:
         return None
 
-    ## get all schemas that contain a table with this name
+    # get all schemas that contain a table with this name
     schemas_w_table = _string_to_array(plpy.execute(
         """SELECT array_agg(table_schema::text) AS schemas
            FROM information_schema.tables
@@ -221,11 +221,11 @@ def get_first_schema(table_name):
         return None
 
     for each_schema in current_schemas:
-    ## get the first schema in search path that contains the table
+        # get the first schema in search path that contains the table
         if each_schema in schemas_w_table:
             return each_schema
 
-    ## None of the schemas in search path have the table
+    # None of the schemas in search path have the table
     return None
 # -------------------------------------------------------------------------
 
@@ -254,7 +254,7 @@ def get_cols(tbl, *args, **kwargs):
         plpy.error('Input error: Table name (NULL) is invalid')
 
     sql_string = """SELECT array_agg(quote_ident(attname)::varchar
-                                        ORDER BY attnum) AS cols
+                                     ORDER BY attnum) AS cols
                     FROM pg_attribute
                     WHERE attrelid = '{tbl}'::regclass
                       AND NOT attisdropped
@@ -297,8 +297,8 @@ def get_cols_and_types(tbl):
                         """.format(tbl=tbl))
     schema = row[0]['table_schema']
     table = row[0]['table_name']
-    sql_string = """SELECT array_agg(quote_ident(column_name)::varchar) AS cols,
-                           array_agg(data_type::varchar) AS types
+    sql_string = """SELECT array_agg(quote_ident(column_name)::varchar ORDER BY ordinal_position) AS cols,
+                           array_agg(data_type::varchar ORDER BY ordinal_position) AS types
                     FROM information_schema.columns
                     WHERE table_name = '{table}'
                     AND table_schema = '{schema}'
@@ -306,7 +306,7 @@ def get_cols_and_types(tbl):
     result = plpy.execute(sql_string)[0]
     col_names = _string_to_array(result['cols'])
     col_types = _string_to_array(result['types'])
-    return dict(zip(col_names, col_types))
+    return list(zip(col_names, col_types))
 # -------------------------------------------------------------------------
 
 
@@ -350,6 +350,26 @@ def columns_exist_in_table(tbl, cols, schema_madlib="madlib"):
 # -------------------------------------------------------------------------
 
 
+def columns_missing_from_table(tbl, cols):
+    """ Get the names in 'cols' that are not present as columns of 'tbl'
+
+    Args:
+        @param tbl Name of source table
+        @param cols Iterable containing column names
+
+    Returns:
+        List of names from 'cols' missing from the table (empty if none)
+    """
+    if not cols:
+        return []
+    existing_cols = set(_unquote_name(i) for i in get_cols(tbl))
+    # a column is considered missing if its name is invalid (None or empty)
+    # or if it is not present in the table (names compared after unquoting)
+    return [col for col in cols
+            if not col or _unquote_name(col) not in existing_cols]
+# -------------------------------------------------------------------------
+
+
 def is_col_array(tbl, col):
     """
     Return True if the column is of an array datatype
@@ -513,10 +533,13 @@ def cols_in_tbl_valid(tbl, cols, module):
     for c in cols:
         if c is None or c.strip() == '':
             plpy.error("{module} error: NULL/empty column name!".format(**locals()))
-    if not columns_exist_in_table(tbl, cols):
-        for c in cols:
-            if not columns_exist_in_table(tbl, [c]):
-                plpy.error("{module} error: Column '{c}' does not exist in table '{tbl}'!".format(**locals()))
+    missing_cols = columns_missing_from_table(tbl, cols)
+
+    # FIXME: still printing just 1 column name for backwards compatibility
+    # this should be changed to print all missing columns
+    if missing_cols:
+        c = missing_cols[0]
+        plpy.error("{module} error: Column '{c}' does not exist in table '{tbl}'!".format(**locals()))
 # -------------------------------------------------------------------------