Posted to commits@madlib.apache.org by nj...@apache.org on 2018/04/05 22:48:12 UTC

[1/3] madlib git commit: MiniBatch Pre-Processor: Add support for grouping

Repository: madlib
Updated Branches:
  refs/heads/master e3d2eee9a -> ab83c95be


MiniBatch Pre-Processor: Add support for grouping

This commit enables grouping for the minibatch preprocessor module.

Other changes:
1. Added an install check test for special characters.
2. Improved error messages and created a reusable function for
testing column dimensions in install check.
3. Added a new optional flag to utils_ind_var_scales_grouping to
create a persistent x_mean table that is reused as the
standardization table by the preprocessor module.
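
With this change the preprocessor can be grouped on one or more columns. A minimal
sketch against the abalone-style table used in install check (assuming the default
madlib schema; passing NULL for grouping_cols keeps the previous, ungrouped behavior):

    SELECT madlib.minibatch_preprocessor(
        'minibatch_preprocessing_input',    -- source table
        'minibatch_preprocessing_out',      -- output table
        'sex',                              -- dependent variable
        'ARRAY[length,diameter,height]',    -- independent variable
        'rings',                            -- grouping_cols (new argument)
        4                                   -- buffer size
    );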

Closes #254

Co-authored-by: Jingyi Mei <jm...@pivotal.io>


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/63261dfb
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/63261dfb
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/63261dfb

Branch: refs/heads/master
Commit: 63261dfbea7d0ba60437a5cc8afc2c3991bfe3e7
Parents: e3d2eee
Author: Nikhil Kak <nk...@pivotal.io>
Authored: Fri Mar 23 11:29:07 2018 -0700
Committer: Nandish Jayaram <nj...@apache.org>
Committed: Thu Apr 5 15:45:38 2018 -0700

----------------------------------------------------------------------
 .../modules/convex/utils_regularization.py_in   |  15 +-
 .../utilities/mean_std_dev_calculator.py_in     |  23 ++-
 .../utilities/minibatch_preprocessing.py_in     | 140 +++++++++----
 .../utilities/minibatch_preprocessing.sql_in    |  14 +-
 .../test/minibatch_preprocessing.sql_in         | 203 ++++++++++++-------
 .../test_minibatch_preprocessing.py_in          |  62 +++---
 .../test/unit_tests/test_utilities.py_in        | 110 +++++-----
 .../postgres/modules/utilities/utilities.py_in  |   9 +-
 8 files changed, 367 insertions(+), 209 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/63261dfb/src/ports/postgres/modules/convex/utils_regularization.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/convex/utils_regularization.py_in b/src/ports/postgres/modules/convex/utils_regularization.py_in
index f56a042..9fc3e8f 100644
--- a/src/ports/postgres/modules/convex/utils_regularization.py_in
+++ b/src/ports/postgres/modules/convex/utils_regularization.py_in
@@ -70,8 +70,9 @@ def utils_ind_var_scales(tbl_data, col_ind_var, dimension, schema_madlib,
 # ========================================================================
 
 def utils_ind_var_scales_grouping(tbl_data, col_ind_var, dimension,
-                                    schema_madlib, grouping_col, x_mean_table,
-                                    set_zero_std_to_one=False):
+                                  schema_madlib, grouping_col, x_mean_table,
+                                  set_zero_std_to_one=False,
+                                  create_temp_table=True):
     """
     The mean and standard deviation for each dimension of an array stored in
     a column. Creates a table containing the mean (array) and std of each
@@ -83,8 +84,10 @@ def utils_ind_var_scales_grouping(tbl_data, col_ind_var, dimension,
         schema_madlib,
         grouping_col,
         x_mean_table,
-        set_zero_std_to_one (optional, default is False. If set to true
+        set_zero_std_to_one: (optional, default is False. If set to true
                      0.0 standard deviation values will be set to 1.0)
+        create_temp_table: If set to False, create a persistent table for
+                           x_mean; otherwise create a temp table
 
     Returns:
         Dictionary with keys 'mean' and 'std' each with a value of an array of
@@ -102,10 +105,12 @@ def utils_ind_var_scales_grouping(tbl_data, col_ind_var, dimension,
     else:
         scaling_uda_name = 'utils_var_scales'
     group_col = _cast_if_null(grouping_col, unique_string('grp_col'))
+    create_table_command = "CREATE TEMP TABLE" if create_temp_table else \
+        "CREATE TABLE"
     x_scales = plpy.execute(
         """
-        CREATE TEMP TABLE {x_mean_table} AS
-        SELECT (f).*, {group_col}
+        {create_table_command} {x_mean_table} AS
+        SELECT {group_col}, (f).*
         FROM (
             SELECT {group_col},
                 {schema_madlib}.__utils_var_scales_result(
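
With create_temp_table=False, the per-group scaling statistics are written to a
persistent table keyed by the grouping columns instead of a temp table. A hand-written
sketch of that table's shape, assuming a single grouping column rings and made-up values:

    CREATE TABLE x_mean_tbl AS
    SELECT rings,
           ARRAY[0.45, 0.11, 0.82]::double precision[] AS mean,  -- illustrative values
           ARRAY[0.10, 0.03, 0.20]::double precision[] AS std
    FROM (VALUES (5), (6)) AS grp(rings);

The preprocessor later renames this table to <output_table>_standardization, which is
why it cannot be a temp table.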

http://git-wip-us.apache.org/repos/asf/madlib/blob/63261dfb/src/ports/postgres/modules/utilities/mean_std_dev_calculator.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/mean_std_dev_calculator.py_in b/src/ports/postgres/modules/utilities/mean_std_dev_calculator.py_in
index e2a1c4f..852c56c 100644
--- a/src/ports/postgres/modules/utilities/mean_std_dev_calculator.py_in
+++ b/src/ports/postgres/modules/utilities/mean_std_dev_calculator.py_in
@@ -24,8 +24,9 @@
 @namespace utilities
 
 """
-
+import plpy
 from convex.utils_regularization import utils_ind_var_scales
+from convex.utils_regularization import utils_ind_var_scales_grouping
 from utilities import _array_to_string
 
 m4_changequote(`<!', `!>')
@@ -40,15 +41,27 @@ class MeanStdDevCalculator:
         self.dimension = dimension
 
     def get_mean_and_std_dev_for_ind_var(self):
-        set_zero_std_to_one = True
-
         x_scaled_vals = utils_ind_var_scales(self.source_table,
                                              self.indep_var_array_str,
                                              self.dimension,
                                              self.schema_madlib,
-                                             None, # do not dump the output to a temp table
-                                             set_zero_std_to_one)
+                                             x_mean_table = None, # do not dump the output to a temp table
+                                             set_zero_std_to_one=True)
         x_mean_str = _array_to_string(x_scaled_vals["mean"])
         x_std_str = _array_to_string(x_scaled_vals["std"])
 
+        if not x_mean_str or not x_std_str:
+            plpy.error("mean/stddev for the independent variable "
+                       "cannot be null")
+
         return x_mean_str, x_std_str
+
+    def create_mean_std_table_for_ind_var_grouping(self, x_mean_table, grouping_cols):
+        utils_ind_var_scales_grouping(self.source_table,
+                                      self.indep_var_array_str,
+                                      self.dimension,
+                                      self.schema_madlib,
+                                      grouping_cols,
+                                      x_mean_table,
+                                      set_zero_std_to_one = True,
+                                      create_temp_table = False)

http://git-wip-us.apache.org/repos/asf/madlib/blob/63261dfb/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index f3766d7..bb9fddd 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -33,6 +33,7 @@ from utilities import is_psql_numeric_type
 from utilities import is_string_formatted_as_array_expression
 from utilities import py_list_to_sql_string
 from utilities import split_quoted_delimited_str
+from utilities import unique_string
 from utilities import _string_to_array
 from utilities import validate_module_input_params
 from mean_std_dev_calculator import MeanStdDevCalculator
@@ -53,13 +54,15 @@ class MiniBatchPreProcessor:
     source table into one row based on the buffer size
     """
     def __init__(self, schema_madlib, source_table, output_table,
-                  dependent_varname, independent_varname, buffer_size, **kwargs):
+                  dependent_varname, independent_varname, grouping_cols,
+                  buffer_size, **kwargs):
         self.schema_madlib = schema_madlib
         self.source_table = source_table
         self.output_table = output_table
         self.dependent_varname = dependent_varname
         self.independent_varname = independent_varname
         self.buffer_size = buffer_size
+        self.grouping_cols = grouping_cols
 
         self.module_name = "minibatch_preprocessor"
         self.output_standardization_table = add_postfix(self.output_table,
@@ -83,8 +86,8 @@ class MiniBatchPreProcessor:
                                              self.source_table,
                                              dep_var_array_str,
                                              indep_var_array_str,
+                                             self.grouping_cols,
                                              self.output_standardization_table)
-        standardize_query = standardizer.get_query_for_standardizing()
 
         num_rows_processed, num_missing_rows_skipped = self.\
                                                 _get_skipped_rows_processed_count(
@@ -119,13 +122,25 @@ class MiniBatchPreProcessor:
         # This ID is the unique row id that get assigned to each row after
         # preprocessing
         unique_row_id = "__id__"
+        standardize_query = standardizer.get_query_for_standardizing()
+
+        partition_by = ''
+        grouping_cols_select_col = ''
+        grouping_cols_group_by = ''
+        if self.grouping_cols:
+            partition_by = 'PARTITION BY {0}'.format(self.grouping_cols)
+            grouping_cols_select_col = self.grouping_cols + ','
+            grouping_cols_group_by = ',' + self.grouping_cols
+
         sql = """
             CREATE TABLE {output_table} AS
             SELECT {row_id},
+                   {grouping_cols_select_col}
                    {schema_madlib}.matrix_agg({dep_colname}) as {dep_colname},
                    {schema_madlib}.matrix_agg({ind_colname}) as {ind_colname}
             FROM (
-                SELECT (row_number() OVER (ORDER BY random()) - 1) / {buffer_size}
+                SELECT (row_number() OVER ({partition_by} ORDER BY random()) - 1)
+                        / {buffer_size}
                             as {row_id}, * FROM
                 (
                     {standardize_query}
@@ -133,7 +148,7 @@ class MiniBatchPreProcessor:
                  WHERE NOT {schema_madlib}.array_contains_null({dep_colname})
                  AND NOT {schema_madlib}.array_contains_null({ind_colname})
             ) sub_query_2
-            GROUP BY {row_id}
+            GROUP BY {row_id} {grouping_cols_group_by}
             {distributed_by_clause}
             """.format(
             schema_madlib=self.schema_madlib,
@@ -150,9 +165,9 @@ class MiniBatchPreProcessor:
             **locals())
         plpy.execute(sql)
 
-
         standardizer.create_output_standardization_table()
         MiniBatchSummarizer.create_output_summary_table(
+            self.output_summary_table,
             self.source_table,
             self.output_table,
             self.dependent_varname,
@@ -161,7 +176,8 @@ class MiniBatchPreProcessor:
             dep_var_classes_str,
             num_rows_processed,
             num_missing_rows_skipped,
-            self.output_summary_table)
+            self.grouping_cols
+            )
 
     def _validate_minibatch_preprocessor_params(self):
         # Test if the independent variable can be typecasted to a double
@@ -174,6 +190,7 @@ class MiniBatchPreProcessor:
         validate_module_input_params(self.source_table, self.output_table,
                                      typecasted_ind_varname,
                                      self.dependent_varname, self.module_name,
+                                     self.grouping_cols,
                                      [self.output_summary_table,
                                       self.output_standardization_table])
 
@@ -313,18 +330,18 @@ class MiniBatchStandardizer:
     3. Creating the output standardization table
     """
     def __init__(self, schema_madlib, source_table, dep_var_array_str,
-                 indep_var_array_str, output_standardization_table):
+                 indep_var_array_str, grouping_cols,
+                 output_standardization_table):
         self.schema_madlib = schema_madlib
         self.source_table = source_table
         self.dep_var_array_str = dep_var_array_str
         self.indep_var_array_str = indep_var_array_str
+        self.grouping_cols = grouping_cols
         self.output_standardization_table = output_standardization_table
 
+        self.x_mean_table = unique_string(desp='x_mean_table')
         self.x_mean_str = None
         self.x_std_dev_str = None
-        self.source_table_row_count = 0
-        self.grouping_cols = "NULL"
-        self.independent_var_dimension = None
         self._calculate_mean_and_std_dev_str()
 
     def _calculate_mean_and_std_dev_str(self):
@@ -338,15 +355,28 @@ class MiniBatchStandardizer:
                                           self.source_table,
                                           self.indep_var_array_str,
                                           self.independent_var_dimension)
-
-        self.x_mean_str, self.x_std_dev_str = calculator.\
-                                              get_mean_and_std_dev_for_ind_var()
-
-        if not self.x_mean_str or not self.x_std_dev_str:
-            plpy.error("mean/stddev for the independent variable"
-                       "cannot be null")
+        """
+        For grouping, we have to create an intermediate mean table because we have
+        to join the mean table and the source table by grouping cols. It's
+        easier to call utils_normalize_data with a table instead of storing this
+        information in memory in a data structure.
+        When there is no grouping, a simple Python string is enough to
+        store the mean and std_dev.
+        """
+        if self.grouping_cols:
+            calculator.create_mean_std_table_for_ind_var_grouping(
+                self.x_mean_table, self.grouping_cols)
+        else:
+            self.x_mean_str, self.x_std_dev_str = calculator.\
+                                            get_mean_and_std_dev_for_ind_var()
 
     def get_query_for_standardizing(self):
+        if self.grouping_cols:
+            return self._get_query_for_standardizing_with_grouping()
+        else:
+            return self._get_query_for_standardizing_without_grouping()
+
+    def _get_query_for_standardizing_without_grouping(self):
         query="""
         SELECT
         {dep_var_array_str} as {dep_colname},
@@ -367,33 +397,66 @@ class MiniBatchStandardizer:
             x_std_dev_str = self.x_std_dev_str)
         return query
 
-    def create_output_standardization_table(self):
-        query = """
-        CREATE TABLE {output_standardization_table} AS
-        select {grouping_cols}::TEXT AS grouping_cols,
-        '{x_mean_str}'::double precision[] AS mean,
-        '{x_std_dev_str}'::double precision[] AS std
+    def _get_query_for_standardizing_with_grouping(self):
+        query="""
+        SELECT
+        {dep_var_array_str} as {dep_colname},
+        {schema_madlib}.utils_normalize_data
+        (
+            {indep_var_array_str},__x__.mean::double precision[], __x__.std::double precision[]
+        ) as {ind_colname},
+        {source_table}.{grouping_cols}
+        FROM
+        {source_table} INNER JOIN {x_mean_table} AS __x__ ON  {source_table}.{grouping_cols} = __x__.{grouping_cols}
         """.format(
-        output_standardization_table = self.output_standardization_table,
-        grouping_cols = self.grouping_cols,
-        x_mean_str = self.x_mean_str,
-        x_std_dev_str = self.x_std_dev_str)
+            source_table = self.source_table,
+            schema_madlib = self.schema_madlib,
+            dep_var_array_str = self.dep_var_array_str,
+            indep_var_array_str = self.indep_var_array_str,
+            dep_colname = MINIBATCH_OUTPUT_DEPENDENT_COLNAME,
+            ind_colname = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME,
+            grouping_cols = self.grouping_cols,
+            x_mean_table = self.x_mean_table,
+            **locals())
+        return query
+
+    def create_output_standardization_table(self):
+        if self.grouping_cols:
+            query = """
+            ALTER TABLE {x_mean_table} RENAME TO {output_standardization_table}
+            """.format(
+            x_mean_table = self.x_mean_table,
+            output_standardization_table = self.output_standardization_table)
+        else:
+            query = """
+            CREATE TABLE {output_standardization_table} AS
+            select '{x_mean_str}'::double precision[] AS mean,
+            '{x_std_dev_str}'::double precision[] AS std
+            """.format(
+            output_standardization_table = self.output_standardization_table,
+            x_mean_str = self.x_mean_str,
+            x_std_dev_str = self.x_std_dev_str)
+
         plpy.execute(query)
 
 class MiniBatchSummarizer:
     @staticmethod
-    def create_output_summary_table(source_table, output_table,
-                                    dep_var_array_str, indep_var_array_str,
-                                    buffer_size, class_values,
-                                    num_rows_processed,
-                                    num_missing_rows_skipped,
-                                    output_summary_table):
+    def create_output_summary_table(output_summary_table, source_table,
+                                    output_table, dep_var_array_str,
+                                    indep_var_array_str, buffer_size,
+                                    class_values, num_rows_processed,
+                                    num_missing_rows_skipped, grouping_cols):
+        # 1. All the string columns are surrounded by "$$" to take care of
+        #    special characters in the column name.
+        # 2. We have to typecast all the string column names to ::TEXT because
+        #    otherwise there is a warning from psql
+        #    WARNING: column "independent_varname" has type "unknown"
         query = """
             CREATE TABLE {output_summary_table} AS
-            SELECT '{source_table}'::TEXT AS source_table,
-            '{output_table}'::TEXT AS output_table,
-            '{dependent_varname}'::TEXT AS dependent_varname,
-            '{independent_varname}'::TEXT AS independent_varname,
+            SELECT $${source_table}$$::TEXT AS source_table,
+            $${output_table}$$::TEXT AS output_table,
+            $${dependent_varname}$$::TEXT AS dependent_varname,
+            $${independent_varname}$$::TEXT AS independent_varname,
             {buffer_size} AS buffer_size,
             {class_values} AS class_values,
             {num_rows_processed} AS num_rows_processed,
@@ -408,7 +471,8 @@ class MiniBatchSummarizer:
                    class_values = class_values,
                    num_rows_processed = num_rows_processed,
                    num_missing_rows_skipped = num_missing_rows_skipped,
-                   grouping_cols = "NULL")
+                   grouping_cols = "$$" + grouping_cols + "$$"
+                                    if grouping_cols else "NULL")
         plpy.execute(query)
 
 class MiniBatchBufferSizeCalculator:
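
When grouping columns are supplied, the standardizing query joins the source table with
the per-group statistics table on the grouping columns before normalizing each row.
Roughly, with the illustrative names used above and the dependent-variable expression
simplified:

    SELECT s.sex AS dependent_varname,
           madlib.utils_normalize_data(
               ARRAY[s.length, s.diameter, s.height],
               x.mean::double precision[],
               x.std::double precision[]) AS independent_varname,
           s.rings
    FROM minibatch_preprocessing_input s
    JOIN x_mean_tbl x ON s.rings = x.rings;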

http://git-wip-us.apache.org/repos/asf/madlib/blob/63261dfb/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in
index 01d91e5..6a48c4f 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.sql_in
@@ -187,6 +187,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor(
     output_table            VARCHAR,
     dependent_varname       VARCHAR,
     independent_varname     VARCHAR,
+    grouping_cols           VARCHAR,
     buffer_size             INTEGER
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(utilities, minibatch_preprocessing)
@@ -199,9 +200,20 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor(
     source_table            VARCHAR,
     output_table            VARCHAR,
     dependent_varname       VARCHAR,
+    independent_varname     VARCHAR,
+    grouping_cols           VARCHAR
+) RETURNS VOID AS $$
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor($1, $2, $3, $4, $5, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.minibatch_preprocessor(
+    source_table            VARCHAR,
+    output_table            VARCHAR,
+    dependent_varname       VARCHAR,
     independent_varname     VARCHAR
 ) RETURNS VOID AS $$
-  SELECT MADLIB_SCHEMA.minibatch_preprocessor($1, $2, $3, $4, NULL);
+  SELECT MADLIB_SCHEMA.minibatch_preprocessor($1, $2, $3, $4, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
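
The new five-argument overload accepts grouping_cols and leaves buffer_size to be
computed by the module, while the four-argument form now forwards (NULL, NULL) so
ungrouped callers behave as before. For example, with the same illustrative table:

    SELECT madlib.minibatch_preprocessor(
        'minibatch_preprocessing_input',
        'minibatch_preprocessing_out',
        'sex',
        'ARRAY[length,diameter,height]',
        'rings');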
 

http://git-wip-us.apache.org/repos/asf/madlib/blob/63261dfb/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
index d49b66f..7eaafb6 100644
--- a/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
+++ b/src/ports/postgres/modules/utilities/test/minibatch_preprocessing.sql_in
@@ -18,6 +18,40 @@
  * under the License.
  *
  *//* ----------------------------------------------------------------------- */
+
+-- The following function is used to assert whether the minibatch preprocessor packs the
+-- data in the expected way, i.e. that the dependent and independent variables have the
+-- correct dimensions.
+
+-- column_name    : allows 'dependent_varname' or 'independent_varname' as
+--                  input to indicate which variable to test against
+-- dimension      : '1' is used to test how many rows (records) get packed into one super row,
+--                  '2' is used to test how many columns in one record get packed into an array
+-- expected_result: an array string with the expected row/col counts in
+--                  ascending order, e.g. '{2,4,4}'
+-- grp            : grouping columns. "NULL" if there is no grouping.
+
+-- See more examples in this file to know how it works.
+
+CREATE OR REPLACE FUNCTION assert_col_dimension(column_name VARCHAR, dimension int, expected_result text, grp text)
+RETURNS void AS
+$$
+    DECLARE
+        qry text;
+        result text;
+    BEGIN
+        IF grp is NULL THEN
+          qry := 'select array_agg(row_count order by row_count asc) from (select array_upper(' || column_name ||' ,' || dimension ||') as row_count from minibatch_preprocessing_out order by row_count asc) s';
+        ELSE
+          qry := 'select array_agg(row_count order by ' || grp || ', row_count asc) from (select array_upper(' || column_name ||' ,' || dimension ||') as row_count,' || grp || ' from minibatch_preprocessing_out order by ' || grp || ' , row_count asc) s';
+        END IF;
+        EXECUTE qry INTO result;
+        IF result != expected_result THEN
+          raise exception 'Dependent/Independent Variable dimension check failed. Actual: % Expected %', result, expected_result;
+        END IF;
+    END;
+$$ LANGUAGE plpgsql;
+
 DROP TABLE IF EXISTS minibatch_preprocessing_input;
 CREATE TABLE minibatch_preprocessing_input(
     sex TEXT,
@@ -46,50 +80,57 @@ INSERT INTO minibatch_preprocessing_input(id,sex,length,diameter,height,whole,sh
 -- no of rows = 10, buffer_size = 4, so assert that count =  10/4 = 3
 \set expected_row_count 3
 DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary;
-SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out',  'length>0.2',  'ARRAY[diameter,height,whole,shucked,viscera,shell]', 4);
+SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out',  'length>0.2',  'ARRAY[diameter,height,whole,shucked,viscera,shell]', NULL, 4);
 SELECT assert
         (
-        row_count = :expected_row_count, 'Row count validation failed for minibatch_preprocessing_out.
+        row_count = :expected_row_count, 'Row count validation failed.
         Expected:' || :expected_row_count || ' Actual: ' || row_count
         ) from (select count(*) as row_count from minibatch_preprocessing_out) s;
 
-\set expected_dep_row_count '\'' 2,4,4 '\''
-\set expected_dep_col_count '\'' 2,2,2 '\''
-\set expected_indep_row_count '\'' 2,4,4 '\''
-\set expected_indep_col_count '\'' 6,6,6 '\''
-
 -- assert dimensions for both dependent and independent variable
-SELECT assert
-        (
-        str_dep_row_count = :expected_dep_row_count, 'Dependent variable row count failed. Actual: ' || str_dep_row_count || ' Expected:' || :expected_dep_row_count
-        ) from
-        (
-        select array_to_string(array_agg(row_count order by row_count asc), ',') as str_dep_row_count from (select array_upper(dependent_varname,1) as row_count from minibatch_preprocessing_out order by row_count asc) s
-        ) s;
+SELECT assert_col_dimension('dependent_varname', 1 , '{2,4,4}', NULL);
+SELECT assert_col_dimension('dependent_varname', 2 , '{2,2,2}', NULL);
+SELECT assert_col_dimension('independent_varname', 1 , '{2,4,4}', NULL);
+SELECT assert_col_dimension('independent_varname', 2 , '{6,6,6}', NULL);
 
 SELECT assert
         (
-        str_dep_col_count = :expected_dep_col_count, 'Dependent variable col count failed. Actual: ' || str_dep_col_count || ' Expected:' || :expected_dep_col_count
-        ) from
-        (
-        select array_to_string(array_agg(col_count order by col_count asc), ',') as str_dep_col_count from (select array_upper(dependent_varname,2) as col_count from minibatch_preprocessing_out order by col_count asc) s
-        ) s;
+        source_table        = 'minibatch_preprocessing_input' AND
+        output_table        = 'minibatch_preprocessing_out' AND
+        dependent_varname   = 'length>0.2' AND
+        independent_varname = 'ARRAY[diameter,height,whole,shucked,viscera,shell]' AND
+        buffer_size         = 4 AND
+        class_values        = '{f,t}' AND -- we sort the class values in python
+        num_rows_processed  = 10 AND
+        num_missing_rows_skipped    = 0 AND
+        grouping_cols       is NULL,
+        'Summary Validation failed. Actual:' || __to_char(summary)
+        ) from (select * from minibatch_preprocessing_out_summary) summary;
 
+-- grouping with the same dataset
+\set expected_grouping_row_count '\'' 1,1,2 '\''
+DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary;
+SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out',  'length>0.2',  'ARRAY[diameter,height,whole,shucked,viscera,shell]', 'rings', 4);
 SELECT assert
         (
-        str_indep_row_count = :expected_indep_row_count, 'Independent variable row count failed. Actual: ' || str_indep_row_count || ' Expected:' || :expected_indep_row_count
+        str_row_count = :expected_grouping_row_count, 'Row count validation failed for minibatch_preprocessing grouping.
+        Expected:' || :expected_grouping_row_count || ' Actual: ' || str_row_count
         ) from
         (
-        select array_to_string(array_agg(row_count order by row_count asc), ',') as str_indep_row_count from (select array_upper(independent_varname, 1) as row_count from minibatch_preprocessing_out order by row_count asc) s
-        ) s;
+            select array_to_string(array_agg(row_count order by row_count asc),',') as str_row_count from
+                (
+                    select  count(*) as row_count
+                    from minibatch_preprocessing_out group by rings order by rings
+                ) s
+        ) s1;
 
-SELECT assert
-        (
-        str_indep_col_count = :expected_indep_col_count, 'Independent variable col count failed. Actual: ' || str_indep_col_count || ' Expected:' || :expected_indep_col_count
-        ) from
-        (
-        select array_to_string(array_agg(col_count order by col_count asc), ',') as str_indep_col_count from (select array_upper(independent_varname,2) as col_count from minibatch_preprocessing_out order by col_count asc) s
-        ) s;
+-- assert dimensions for both the dependent and independent variables; note that in
+-- each query the result is ordered by grouping_cols and row count, so we end up
+-- with a single array string
+SELECT assert_col_dimension('dependent_varname', 1 , '{1,1,4,4}', 'rings');
+SELECT assert_col_dimension('dependent_varname', 2 , '{2,2,2,2}', 'rings');
+SELECT assert_col_dimension('independent_varname', 1 , '{1,1,4,4}', 'rings');
+SELECT assert_col_dimension('independent_varname', 2 , '{6,6,6,6}', 'rings');
 
 SELECT assert
         (
@@ -101,33 +142,50 @@ SELECT assert
         class_values        = '{f,t}' AND -- we sort the class values in python
         num_rows_processed  = 10 AND
         num_missing_rows_skipped    = 0 AND
-        grouping_cols       is NULL,
-        'Summary Validation failed. Expected:' || __to_char(summary)
+        grouping_cols       = 'rings',
+        'Summary Validation failed for grouping col. Actual:' || __to_char(summary)
         ) from (select * from minibatch_preprocessing_out_summary) summary;
 
+-- Test that the standardization table gets created.
+select count(*) from minibatch_preprocessing_out_standardization;
+-- Test that the summary table gets created.
+select count(*) from minibatch_preprocessing_out_summary;
 
--- Test null values in x and y
-\set expected_row_count 1
+-- Test null values in x and y both with and without grouping
 DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary;
-
 TRUNCATE TABLE minibatch_preprocessing_input;
 INSERT INTO minibatch_preprocessing_input(id,sex,length,diameter,height,whole,shucked,viscera,shell,rings) VALUES
-(1040,'F',0.66,0.475,0.18,NULL,0.641,0.294,0.335,6),
-(3160,'F',0.34,0.35,0.085,0.204,0.097,0.021,0.05,6),
+(1040,'F',0.66,0.475,0.18,NULL,0.641,0.294,0.335,5),
 (3984,NULL,0.585,0.45,0.25,0.874,0.3545,0.2075,0.225,5),
-(861,'M',0.595,0.475,NULL,1.1405,0.547,0.231,0.271,6),
-(932,NULL,0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195,6),
+(861,'M',0.595,0.475,NULL,1.1405,0.547,0.231,0.271,5),
+(932,NULL,0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195,5),
+(698,'F',0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195,6),
+(698,'F',0.445,0.335,0.11,0.4355,NULL,0.1095,0.1195,6),
 (698,'F',0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195,6),
-(922,NULL,0.445,0.335,0.11,NULL,0.2025,0.1095,0.1195,6);
-SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'sex', 'ARRAY[length,diameter,height,whole,shucked,viscera,shell]', 2);
+(922,NULL,0.445,0.335,0.11,NULL,0.2025,0.1095,0.1195,5),
+(942,'F',0.445,0.335,0.11,0.25,0.2025,0.1095,0.1195,5);
+SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'sex', 'ARRAY[length,diameter,height,whole,shucked,viscera,shell]', NULL, 2);
 SELECT assert
         (
-        row_count = :expected_row_count, 'Row count validation failed for minibatch_preprocessing_out.
-        Expected:' || :expected_row_count || ' Actual: ' || row_count
+        row_count = 2, 'Row count validation failed with null values.
+        Expected:' || 2 || ' Actual: ' || row_count
         ) from (select count(*) as row_count from minibatch_preprocessing_out) s;
 SELECT assert
-        (num_rows_processed = 2 AND num_missing_rows_skipped = 5,
-         'Rows processed/skipped validation failed for minibatch_preprocessing_out_summary.
+        (num_rows_processed = 3 AND num_missing_rows_skipped = 6,
+         'Rows processed/skipped validation failed for the summary table.
+         Actual num_rows_processed:' || num_rows_processed || ', Actual num_missing_rows_skipped: ' || num_missing_rows_skipped
+        ) from (select * from minibatch_preprocessing_out_summary) s;
+
+DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary;
+SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'sex', 'ARRAY[length,diameter,height,whole,shucked,viscera,shell]', 'rings', 2);
+SELECT assert
+        (
+        row_count = 2, 'Row count validation failed with null values and grouping.
+        Expected:' || 2 || ' Actual: ' || row_count
+        ) from (select count(*) as row_count from minibatch_preprocessing_out) s;
+SELECT assert
+        (num_rows_processed = 3 AND num_missing_rows_skipped = 6,
+         'Rows processed/skipped validation failed for the summary table with grouping.
          Actual num_rows_processed:' || num_rows_processed || ', Actual num_missing_rows_skipped: ' || num_missing_rows_skipped
         ) from (select * from minibatch_preprocessing_out_summary) s;
 
@@ -138,7 +196,7 @@ CREATE TABLE minibatch_preprocessing_input(x1 INTEGER ,x2 INTEGER ,y TEXT);
 INSERT INTO minibatch_preprocessing_input(x1,x2,y) VALUES
 (2,10,'y1'),
 (4,30,'y2');
-SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'y', 'ARRAY[x1,x2]', 2);
+SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'y', 'ARRAY[x1,x2]', NULL, 2);
 
 -- since the order is not deterministic, we assert for all possible orders
 \set expected_normalized_independent_var1 '\'' {{-1, -1},{1, 1}} '\''
@@ -154,28 +212,10 @@ SELECT assert
     select __to_char(independent_varname) as independent_varname from minibatch_preprocessing_out
 ) s;
 
-
 -- Test that the standardization table gets created.
-\set expected_row_count 1
-SELECT assert
-(
-  row_count = :expected_row_count, 'Row count validation failed for minibatch_preprocessing_out_standardization.
-        Expected:' || :expected_row_count || ' Actual: ' || row_count
-) from
-(
-  select count(*) as row_count from minibatch_preprocessing_out_standardization
-) s;
-
+select count(*) from minibatch_preprocessing_out_standardization;
 -- Test that the summary table gets created.
-\set expected_row_count 1
-SELECT assert
-(
-  row_count = :expected_row_count, 'Row count validation failed for minibatch_preprocessing_out_summary.
-        Expected:' || :expected_row_count || ' Actual: ' || row_count
-) from
-(
-  select count(*) as row_count from minibatch_preprocessing_out_summary
-) s;
+select count(*) from minibatch_preprocessing_out_summary;
 
 -- Test for array values in indep column
 DROP TABLE IF EXISTS minibatch_preprocessing_input;
@@ -194,19 +234,19 @@ INSERT INTO minibatch_preprocessing_input(id,sex,attributes) VALUES
 (932,NULL,ARRAY[0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195]),
 (NULL,'F',ARRAY[0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195]),
 (922,NULL,ARRAY[0.445,0.335,0.11,NULL,0.2025,0.1095,0.1195]);
-SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'sex', 'attributes', 1);
+SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'sex', 'attributes', NULL, 1);
 SELECT assert
         (
-        row_count = 2, 'Row count validation failed for minibatch_preprocessing_out.
+        row_count = 2, 'Row count validation failed with array values in independent variable.
         Expected:' || 2 || ' Actual: ' || row_count
         ) from (select count(*) as row_count from minibatch_preprocessing_out) s;
 
 -- Test for array values in dep column
 DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary;
-SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'attributes', 'ARRAY[id]', 1);
+SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', 'attributes', 'ARRAY[id]', NULL, 1);
 SELECT assert
         (
-        row_count = 3, 'Row count validation failed array values in dependent variable.
+        row_count = 3, 'Row count validation failed with array values in dependent variable.
         Expected:' || 3 || ' Actual: ' || row_count
         ) from (select count(*) as row_count from minibatch_preprocessing_out) s;
 
@@ -219,3 +259,28 @@ SELECT assert
         Buffer size from summary table: ' || buffer_size || ' does not match the output table:'
         || ind_var_rows
         ) from (select max(array_upper(o.dependent_varname, 1)) as dep_var_rows, max(array_upper(o.independent_varname, 1)) as ind_var_rows , s1.buffer_size from minibatch_preprocessing_out o, minibatch_preprocessing_out_summary s1 group by buffer_size) s;
+
+-- Test special characters in independent_var, dependent_var and grouping_cols
+DROP TABLE IF EXISTS minibatch_preprocessing_input;
+CREATE TABLE minibatch_preprocessing_input(
+    "se''x" TEXT,
+    "len'%*()gth" DOUBLE PRECISION[],
+    "rin!#'gs" INTEGER);
+
+INSERT INTO minibatch_preprocessing_input VALUES
+('F',ARRAY[0.66, 0.5],6);
+DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary;
+SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', '"se''''x"', '"len''%*()gth"', '"rin!#''gs"');
+SELECT assert
+        (
+        source_table        = 'minibatch_preprocessing_input' AND
+        output_table        = 'minibatch_preprocessing_out' AND
+        dependent_varname   = '"se''''x"' AND
+        independent_varname = '"len''%*()gth"' AND
+        buffer_size         = 1 AND
+        class_values        = '{F}' AND -- we sort the class values in python
+        num_rows_processed  = 1 AND
+        num_missing_rows_skipped    = 0 AND
+        grouping_cols       = '"rin!#''gs"',
+        'Summary Validation failed for special chars. Actual:' || __to_char(summary)
+        ) from (select * from minibatch_preprocessing_out_summary) summary;
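
The case above works because the summary table now stores its text values dollar-quoted,
so identifiers containing quotes and other special characters round-trip unchanged. A
minimal illustration with the grouping column name from this test:

    -- $$...$$ avoids doubling every single quote inside the literal
    SELECT $$"rin!#'gs"$$::TEXT AS grouping_cols;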

http://git-wip-us.apache.org/repos/asf/madlib/blob/63261dfb/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
index c8b6942..5f83e87 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
@@ -50,6 +50,7 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
         self.default_output_table = "output"
         self.default_dep_var = "depvar"
         self.default_ind_var = "indvar"
+        self.grouping_cols = None
         self.default_buffer_size = 5
 
         import minibatch_preprocessing
@@ -77,6 +78,7 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
                                                          "out",
                                                          self.default_dep_var,
                                                          self.default_ind_var,
+                                                         self.grouping_cols,
                                                          self.default_buffer_size)
         self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
                                                 "num_rows_processed":3}], ""]
@@ -90,6 +92,7 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
                                                          "out",
                                                          self.default_dep_var,
                                                          self.default_ind_var,
+                                                         self.grouping_cols,
                                                          None)
         self.plpy_mock_execute.side_effect = [[{"source_table_row_count":5 ,
         "num_rows_processed":3}], ""]
@@ -104,6 +107,7 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
                                                                      self.default_output_table,
                                                                      "y1,y2",
                                                                      self.default_ind_var,
+                                                                     self.grouping_cols,
                                                                      self.default_buffer_size)
 
     def test_minibatch_preprocessor_buffer_size_zero_fails(self):
@@ -113,6 +117,7 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
                                                              self.default_output_table,
                                                              self.default_dep_var,
                                                              self.default_ind_var,
+                                                             self.grouping_cols,
                                                              0)
 
     def test_minibatch_preprocessor_buffer_size_one_passes(self):
@@ -122,6 +127,7 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
                                                              self.default_output_table,
                                                              self.default_dep_var,
                                                              self.default_ind_var,
+                                                             self.grouping_cols,
                                                              1)
         preprocessor_obj.minibatch_preprocessor()
         self.assert_(True)
@@ -152,7 +158,8 @@ class MiniBatchQueryFormatterTestCase(unittest.TestCase):
         self.module_patcher.stop()
 
     def test_get_dep_var_array_str_text_type(self):
-        self.plpy_mock_execute.return_value = [{"class":100},{"class":0},{"class":22}]
+        self.plpy_mock_execute.return_value = [{"class":100},{"class":0},
+                                               {"class":22}]
 
         dep_var_array_str, _ = self.subject.get_dep_var_array_and_classes\
                                                 (self.default_dep_var, 'text')
@@ -166,13 +173,15 @@ class MiniBatchQueryFormatterTestCase(unittest.TestCase):
         self.plpy_mock_execute.return_value = [{"class":3}]
 
         dep_var_array_str, _ = self.subject.\
-                            get_dep_var_array_and_classes(self.default_dep_var, 'boolean')
+                            get_dep_var_array_and_classes(self.default_dep_var,
+                                                          'boolean')
         self.assertEqual("ARRAY[({0}) = '3']::integer[]".
                          format(self.default_dep_var), dep_var_array_str)
 
     def test_get_dep_var_array_str_array_type(self):
         dep_var_array_str, _ = self.subject.\
-                        get_dep_var_array_and_classes(self.default_dep_var, 'some_array[]')
+                        get_dep_var_array_and_classes(self.default_dep_var,
+                                                      'some_array[]')
 
         self.assertEqual(self.default_dep_var, dep_var_array_str)
 
@@ -180,11 +189,13 @@ class MiniBatchQueryFormatterTestCase(unittest.TestCase):
         dep_var_array_str, _ = self.subject. \
             get_dep_var_array_and_classes(self.default_dep_var, 'integer')
 
-        self.assertEqual("ARRAY[{0}]".format(self.default_dep_var), dep_var_array_str)
+        self.assertEqual("ARRAY[{0}]".format(self.default_dep_var),
+                                            dep_var_array_str)
 
     def test_get_dep_var_array_str_other_type(self):
         with self.assertRaises(Exception):
-            self.subject.get_dep_var_array_and_classes(self.default_dep_var, 'other')
+            self.subject.get_dep_var_array_and_classes(self.default_dep_var,
+                                                       'other')
 
     def test_get_indep_var_array_str_passes(self):
         ind_var_array_str = self.subject.get_indep_var_array_str('ARRAY[x1,x2,x3]')
@@ -203,7 +214,9 @@ class MiniBatchQueryStandardizerTestCase(unittest.TestCase):
         }
         self.x_mean = "5678"
         self.x_std_dev = "4.789"
-        self.mean_std_calculator_mock.MeanStdDevCalculator.return_value.get_mean_and_std_dev_for_ind_var = Mock(return_value=(self.x_mean, self.x_std_dev))
+        self.mean_std_calculator_mock.MeanStdDevCalculator.return_value.\
+            get_mean_and_std_dev_for_ind_var = \
+            Mock(return_value=(self.x_mean, self.x_std_dev))
 
         # we need to use MagicMock() instead of Mock() for the plpy.execute mock
         # to be able to iterate on the return value
@@ -219,42 +232,31 @@ class MiniBatchQueryStandardizerTestCase(unittest.TestCase):
                                                          self.default_source_table,
                                                          self.default_dep_var,
                                                          self.default_ind_var,
+                                                         None,
                                                          "out_standardization")
 
     def tearDown(self):
         self.module_patcher.stop()
 
-    def test_get_query_for_standardizing_no_exception(self):
+    def test_get_query_for_standardizing_no_grouping(self):
         self.subject.get_query_for_standardizing()
+        self.assertEqual(self.x_mean, self.subject.x_mean_str)
+        self.assertEqual(self.x_std_dev, self.subject.x_std_dev_str)
 
-    def test_get_query_for_standardizing_null_mean_raises_exception(self):
-        self.mean_std_calculator_mock.MeanStdDevCalculator.return_value.get_mean_and_std_dev_for_ind_var = Mock(return_value=(None, self.x_std_dev))
-        with self.assertRaises(Exception):
-            self.module.MiniBatchStandardizer(self.default_schema,
-                                              self.default_source_table,
-                                              self.default_dep_var,
-                                              self.default_ind_var,
-                                              "does_not_matter")
-
-    def test_get_query_for_standardizing_null_stddev_raises_exception(self):
-        self.mean_std_calculator_mock.MeanStdDevCalculator.return_value.get_mean_and_std_dev_for_ind_var = Mock(return_value=(self.x_mean, None))
-        with self.assertRaises(Exception):
-            self.module.MiniBatchStandardizer(self.default_schema,
-                                              self.default_source_table,
-                                              self.default_dep_var,
-                                              self.default_ind_var,
-                                              "does_not_matter")
-
-    def test_get_calculated_mean_and_std_dev_returns_values(self):
+    def test_get_query_for_standardizing_grouping(self):
+        self.subject = self.module.MiniBatchStandardizer(self.default_schema,
+                                                         self.default_source_table,
+                                                         self.default_dep_var,
+                                                         self.default_ind_var,
+                                                         "grp",
+                                                         "out_standardization")
         self.subject.get_query_for_standardizing()
-        mean, std_dev = self.subject.x_mean_str, self.subject.x_std_dev_str
-        self.assertEqual(self.x_mean, mean)
-        self.assertEqual(self.x_std_dev, std_dev)
 
     def test_create_standardization_output_table_executes_query(self):
         self.subject.create_output_standardization_table()
         expected_query_substr_create_table = "CREATE TABLE out_standardization AS"
-        self.plpy_mock_execute.assert_called_with(AnyStringWith(expected_query_substr_create_table))
+        self.plpy_mock_execute.assert_called_with(AnyStringWith(
+                                            expected_query_substr_create_table))
         self.plpy_mock_execute.assert_called_with(AnyStringWith(self.x_mean))
         self.plpy_mock_execute.assert_called_with(AnyStringWith(self.x_std_dev))
 

http://git-wip-us.apache.org/repos/asf/madlib/blob/63261dfb/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
index 1109eeb..c40142d 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
@@ -61,91 +61,81 @@ class UtilitiesTestCase(unittest.TestCase):
     def tearDown(self):
         self.module_patcher.stop()
 
-    def test_validate_module_input_params_all_nulls(self):
-        with self.assertRaises(Exception) as context:
-            self.subject.validate_module_input_params(None, None, None, None, "unittest_module")
-
-        expected_exception = Exception("unittest_module error: NULL/empty input table name!")
-        self.assertEqual(expected_exception.message, context.exception.message)
-
-    def test_validate_module_input_params_source_table_null(self):
-        with self.assertRaises(Exception) as context:
-            self.subject.validate_module_input_params(None, self.default_output_table,
-                                                      self.default_ind_var,
-                                                      self.default_dep_var,
+    def test_validate_module_input_params_source_and_output_table_are_tested(self):
+        self.subject.input_tbl_valid = Mock()
+        self.subject.output_tbl_valid = Mock()
+        self.subject.validate_module_input_params(self.default_source_table,
+                                                  self.default_output_table,
+                                                  self.default_ind_var,
+                                                  self.default_dep_var,
+                                                  self.default_module, None)
+        self.subject.input_tbl_valid.assert_any_call(self.default_source_table,
                                                       self.default_module)
-
-        expected_exception = "unittest_module error: NULL/empty input table name!"
-        self.assertEqual(expected_exception, context.exception.message)
-
-    def test_validate_module_input_params_output_table_null(self):
-        with self.assertRaises(Exception) as context:
-            self.subject.validate_module_input_params(self.default_source_table, None,
-                                                      self.default_ind_var,
-                                                      self.default_dep_var,
+        self.subject.output_tbl_valid.assert_any_call(self.default_output_table,
                                                       self.default_module)
 
-        expected_exception = "unittest_module error: NULL/empty output table name!"
-        self.assertEqual(expected_exception, context.exception.message)
-
-    @patch('validate_args.table_exists', return_value=Mock())
-    def test_validate_module_input_params_output_table_exists(self,
-                                                              table_exists_mock):
+    def test_validate_module_input_params_assert_other_tables_dont_exist(self):
         self.subject.input_tbl_valid = Mock()
-        table_exists_mock.side_effect = [True]
-        with self.assertRaises(Exception) as context:
-            self.subject.validate_module_input_params(self.default_source_table,
+        self.subject.output_tbl_valid = Mock()
+        self.subject.validate_module_input_params(self.default_source_table,
                                                       self.default_output_table,
                                                       self.default_ind_var,
                                                       self.default_dep_var,
-                                                      self.default_module)
+                                                      self.default_module,
+                                                      None,
+                                                      ['foo','bar'])
+        self.subject.output_tbl_valid.assert_any_call('foo', self.default_module)
+        self.subject.output_tbl_valid.assert_any_call('bar', self.default_module)
 
-        expected_exception = "unittest_module error: Output table '{0}' already exists.".format(self.default_output_table)
-        self.assertTrue(expected_exception in context.exception.message)
 
-    @patch('validate_args.table_exists', return_value=Mock())
-    def test_validate_module_input_params_assert_other_tables_dont_exist(self, table_exists_mock):
+    def test_validate_module_input_params_ind_var_null(self):
         self.subject.input_tbl_valid = Mock()
-        table_exists_mock.side_effect = [False, False, True]
+        self.subject.output_tbl_valid = Mock()
+        self.subject.is_var_valid = Mock(side_effect = [False, True, True])
         with self.assertRaises(Exception) as context:
             self.subject.validate_module_input_params(self.default_source_table,
                                                       self.default_output_table,
-                                                      self.default_ind_var,
+                                                      "invalid_indep_var",
                                                       self.default_dep_var,
                                                       self.default_module,
-                                                      ['foo','bar'])
+                                                      None)
 
-        expected_exception = "unittest_module error: Output table 'bar' already exists."
-        self.assertTrue(expected_exception in context.exception.message)
+        expected_exception = "unittest_module error: invalid independent_varname " \
+                             "('invalid_indep_var') for source_table (source)!"
+        self.assertEqual(expected_exception, context.exception.message)
+
+    def test_validate_module_input_params_dep_var_invalid(self):
+        self.subject.input_tbl_valid = Mock()
+        self.subject.output_tbl_valid = Mock()
+        self.subject.is_var_valid = Mock(side_effect = [True, False, True])
 
-    @patch('validate_args.table_is_empty', return_value=False)
-    @patch('validate_args.table_exists', return_value=Mock())
-    def test_validate_module_input_params_ind_var_null(self, table_exists_mock,
-                                                             table_is_empty_mock):
-        table_exists_mock.side_effect = [True, False]
         with self.assertRaises(Exception) as context:
             self.subject.validate_module_input_params(self.default_source_table,
                                                       self.default_output_table,
-                                                      None,
-                                                      self.default_dep_var,
-                                                      self.default_module)
+                                                      self.default_ind_var,
+                                                      "invalid_dep_var",
+                                                      self.default_module, None)
 
-        expected_exception = "unittest_module error: invalid independent_varname ('None') for source_table (source)!"
+        expected_exception = "unittest_module error: invalid dependent_varname " \
+                             "('invalid_dep_var') for source_table (source)!"
         self.assertEqual(expected_exception, context.exception.message)
-        # is_var_valid_mock.assert_called_once_with(self.default_source_table, self.default_ind_var)
 
-    @patch('validate_args.table_exists', return_value=Mock())
-    @patch('validate_args.table_is_empty', return_value=False)
-    def test_validate_module_input_params_dep_var_null(self, table_is_empty_mock, table_exists_mock):
-        table_exists_mock.side_effect = [True, False]
+    def test_validate_module_input_params_grouping_cols_invalid(self):
+        self.subject.input_tbl_valid = Mock()
+        self.subject.output_tbl_valid = Mock()
+        is_var_valid_mock = Mock()
+        is_var_valid_mock.side_effect = [True, True, False]
+        self.subject.is_var_valid = is_var_valid_mock
         with self.assertRaises(Exception) as context:
             self.subject.validate_module_input_params(self.default_source_table,
-                                                      self.default_output_table,
-                                                      self.default_ind_var,
-                                                      None,
-                                                      self.default_module)
-
-        expected_exception = "unittest_module error: invalid dependent_varname ('None') for source_table (source)!"
+                                                  self.default_output_table,
+                                                  self.default_ind_var,
+                                                  self.default_dep_var,
+                                                  self.default_module,
+                                                  'invalid_grp_col')
+
+        expected_exception = "unittest_module error: invalid grouping_cols " \
+                             "('invalid_grp_col') for source_table (source)!"
         self.assertEqual(expected_exception, context.exception.message)
 
     def test_is_var_valid_all_nulls(self):

http://git-wip-us.apache.org/repos/asf/madlib/blob/63261dfb/src/ports/postgres/modules/utilities/utilities.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 133f4ac..320082c 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -795,7 +795,7 @@ def collate_plpy_result(plpy_result_rows):
 
 
 def validate_module_input_params(source_table, output_table, independent_varname,
-                                 dependent_varname, module_name,
+                                 dependent_varname, module_name, grouping_cols,
                                  other_output_tables=None):
     """
     This function is supposed to be used for validating params for
@@ -833,6 +833,13 @@ def validate_module_input_params(source_table, output_table, independent_varname
             "({source_table})!".format(module_name=module_name,
                                        dependent_varname=dependent_varname,
                                        source_table=source_table))
+    if grouping_cols:
+        _assert(is_var_valid(source_table, grouping_cols),
+                "{module_name} error: invalid grouping_cols "
+                "('{grouping_cols}') for source_table "
+                "({source_table})!".format(module_name=module_name,
+                                           grouping_cols=grouping_cols,
+                                           source_table=source_table))
 # ------------------------------------------------------------------------
 
 import unittest


[2/3] madlib git commit: Utilities: Add unit test file for validate args

Posted by nj...@apache.org.
Utilities: Add unit test file for validate args

This commit adds a new unit test file for the validate_args Python module.
Only two functions are covered for now: input_tbl_valid and
output_tbl_valid.
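
For context, the testing pattern used in the new file is roughly the
following (a minimal, self-contained sketch, not code from the commit; it
assumes validate_args is importable from the module path and stubs the
database-only plpy module the same way the committed setUp does):

    import unittest
    from mock import MagicMock, Mock, patch

    # plpy only exists inside the database, so stub it before importing
    # validate_args, mirroring the setUp in the committed test file.
    plpy_stub = MagicMock()
    with patch.dict('sys.modules', {'plpy': plpy_stub}):
        import validate_args

    class ValidateArgsSketch(unittest.TestCase):
        def test_input_tbl_valid_missing_table_raises(self):
            # Make the stubbed plpy.error surface as an exception.
            plpy_stub.error.side_effect = Exception
            # Stub the catalog lookup so no live database is needed.
            validate_args.table_exists = Mock(return_value=False)
            with self.assertRaises(Exception):
                validate_args.input_tbl_valid("foo", "unittest_module")

    if __name__ == '__main__':
        unittest.main()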


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/00c42af6
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/00c42af6
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/00c42af6

Branch: refs/heads/master
Commit: 00c42af65226e07c1205e3381b02374d7cfa0d64
Parents: 63261df
Author: Nikhil Kak <nk...@pivotal.io>
Authored: Thu Mar 29 20:09:53 2018 -0700
Committer: Nandish Jayaram <nj...@apache.org>
Committed: Thu Apr 5 15:45:50 2018 -0700

----------------------------------------------------------------------
 .../test/unit_tests/test_validate_args.py_in    | 126 +++++++++++++++++++
 1 file changed, 126 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/00c42af6/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
new file mode 100644
index 0000000..8c59256
--- /dev/null
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
@@ -0,0 +1,126 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+from os import path
+# Add utilities module to the pythonpath.
+sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
+
+
+import unittest
+from mock import *
+import sys
+import plpy_mock as plpy
+
+m4_changequote(`<!', `!>')
+class ValidateArgsTestCase(unittest.TestCase):
+    def setUp(self):
+        patches = {
+            'plpy': plpy
+        }
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+
+        import validate_args
+        self.subject = validate_args
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
+    def test_input_tbl_valid_null_tbl_raises_exception(self):
+        with self.assertRaises(Exception) as context:
+          self.subject.input_tbl_valid(None, "unittest_module")
+        
+        self.assertEqual("unittest_module error: NULL/empty input"
+                          " table name!", context.exception.message)
+
+    def test_input_tbl_valid_whitespaces_tbl_raises(self):
+        with self.assertRaises(Exception) as context:
+          self.subject.input_tbl_valid("  ", "unittest_module")
+        
+        self.assertEqual("unittest_module error: NULL/empty input"
+                          " table name!", context.exception.message)
+
+    def test_input_tbl_valid_table_not_exists_raises(self):
+				self.subject.table_exists = Mock(return_value=False)
+				with self.assertRaises(Exception) as context:
+					self.subject.input_tbl_valid("foo", "unittest_module")
+        
+				self.assertEqual("unittest_module error: Input table"
+                          " 'foo' does not exist", context.exception.message)
+
+    def test_input_tbl_valid_table_tbl_empty_raises(self):
+				self.subject.table_exists = Mock(return_value=True)
+				self.subject.table_is_empty = Mock(return_value=True)
+				with self.assertRaises(Exception) as context:
+					self.subject.input_tbl_valid("foo", "unittest_module")
+        
+				self.assertEqual("unittest_module error: Input table"
+                          " 'foo' is empty!", context.exception.message)
+
+    def test_input_tbl_valid_table_tbl_empty_passes(self):
+				self.subject.table_exists = Mock(return_value=True)
+				self.subject.table_is_empty = Mock(return_value=True)
+				self.subject.input_tbl_valid("foo", "unittest_module",
+																				check_empty = False)
+
+    def test_input_tbl_valid_table_passes(self):
+				self.subject.table_exists = Mock(return_value=True)
+				self.subject.table_is_empty = Mock(return_value=False)
+				self.subject.input_tbl_valid("foo", "unittest_module")
+        
+    def test_output_tbl_valid_null_tbl_raises_exception(self):
+        with self.assertRaises(Exception) as context:
+          self.subject.output_tbl_valid(None, "unittest_module")
+        
+        self.assertEqual("unittest_module error: NULL/empty output"
+                          " table name!", context.exception.message)
+
+    def test_output_tbl_valid_whitespaces_tbl_raises_exception(self):
+        with self.assertRaises(Exception) as context:
+          self.subject.output_tbl_valid("  ", "unittest_module")
+        
+        self.assertEqual("unittest_module error: NULL/empty output"
+                          " table name!", context.exception.message)
+
+    def test_output_tbl_valid_null_raises(self):
+        with self.assertRaises(Exception) as context:
+          self.subject.output_tbl_valid(" null ", "unittest_module")
+        
+        self.assertEqual("unittest_module error: NULL/empty output"
+                          " table name!", context.exception.message)
+
+    def test_output_tbl_valid_tbl_exists_raises(self):
+				self.subject.table_exists = Mock(return_value=True)
+				with self.assertRaises(Exception) as context:
+					self.subject.output_tbl_valid(" foo ", "unittest_module")
+        
+				self.assertEqual("unittest_module error: Output table foo"
+												" already exists.\n Drop it before calling"
+												" the function.", context.exception.message)
+
+    def test_output_tbl_valid_table_passes(self):
+				self.subject.table_exists = Mock(return_value=False)
+				self.subject.output_tbl_valid("foo", "unittest_module")
+
+if __name__ == '__main__':
+    unittest.main()


[3/3] madlib git commit: UnitTests: Raise custom exception for mocked plpy error.

Posted by nj...@apache.org.
UnitTests: Raise custom exception for mocked plpy error.

Before this commit, every unit test that wanted to verify that
plpy.error was called had to assert that a generic Exception was raised.
That check was too broad: it could not distinguish an exception raised
by the plpy mock class from any other exception.
With this commit, the mock raises a custom PLPYException, so tests no
longer need to compare error message strings. Asserting on the exception
type is proof enough that plpy.error was called.
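
The new pattern, in isolation, looks roughly like this (a minimal,
self-contained illustration; PLPYException and error() mirror the
plpy_mock change below, while ExampleTestCase is a hypothetical test,
not one from the commit):

    import unittest

    # Mirrors the PLPYException added to plpy_mock.py_in in this commit.
    class PLPYException(Exception):
        def __init__(self, message):
            super(PLPYException, self).__init__()
            self.message = message

        def __str__(self):
            return repr(self.message)

    # plpy_mock.error now raises the custom type instead of a bare Exception.
    def error(message):
        raise PLPYException(message)

    class ExampleTestCase(unittest.TestCase):
        def test_error_raises_custom_exception(self):
            # Asserting on the type alone proves plpy.error was reached;
            # comparing the full message string is no longer necessary.
            with self.assertRaises(PLPYException):
                error("unittest_module error: something went wrong")

    if __name__ == '__main__':
        unittest.main()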


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/ab83c95b
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/ab83c95b
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/ab83c95b

Branch: refs/heads/master
Commit: ab83c95be4dde60d47c6d948172e6f77de773dbb
Parents: 00c42af
Author: Nikhil Kak <nk...@pivotal.io>
Authored: Fri Mar 30 11:29:22 2018 -0700
Committer: Nandish Jayaram <nj...@apache.org>
Committed: Thu Apr 5 15:45:50 2018 -0700

----------------------------------------------------------------------
 .../convex/test/unit_tests/plpy_mock.py_in      | 11 ++-
 .../convex/test/unit_tests/test_mlp_igd.py_in   | 34 ++++----
 .../utilities/test/unit_tests/plpy_mock.py_in   | 11 ++-
 .../test_minibatch_preprocessing.py_in          |  6 +-
 .../test/unit_tests/test_utilities.py_in        | 18 +----
 .../test/unit_tests/test_validate_args.py_in    | 84 ++++++++------------
 6 files changed, 74 insertions(+), 90 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/ab83c95b/src/ports/postgres/modules/convex/test/unit_tests/plpy_mock.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/convex/test/unit_tests/plpy_mock.py_in b/src/ports/postgres/modules/convex/test/unit_tests/plpy_mock.py_in
index 3058830..dd18649 100644
--- a/src/ports/postgres/modules/convex/test/unit_tests/plpy_mock.py_in
+++ b/src/ports/postgres/modules/convex/test/unit_tests/plpy_mock.py_in
@@ -22,7 +22,7 @@ def __init__(self):
     pass
 
 def error(message):
-    raise Exception(message)
+    raise PLPYException(message)
 
 def execute(query):
     pass
@@ -32,3 +32,12 @@ def warning(query):
 
 def info(query):
     print query
+
+
+class PLPYException(Exception):
+    def __init__(self, message):
+        super(PLPYException, self).__init__()
+        self.message = message
+
+    def __str__(self):
+        return repr(self.message)

http://git-wip-us.apache.org/repos/asf/madlib/blob/ab83c95b/src/ports/postgres/modules/convex/test/unit_tests/test_mlp_igd.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/convex/test/unit_tests/test_mlp_igd.py_in b/src/ports/postgres/modules/convex/test/unit_tests/test_mlp_igd.py_in
index d6d1cc1..986687e 100644
--- a/src/ports/postgres/modules/convex/test/unit_tests/test_mlp_igd.py_in
+++ b/src/ports/postgres/modules/convex/test/unit_tests/test_mlp_igd.py_in
@@ -53,16 +53,16 @@ class MLPMiniBatchTestCase(unittest.TestCase):
     @patch('utilities.validate_args.table_exists', return_value=False)
     def test_mlp_preprocessor_input_table_invalid_raises_exception(
                         self, mock1):
-        with self.assertRaises(Exception):
-            self.subject.MLPPreProcessor("input")
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.MLPMinibatchPreProcessor("input")
 
     @patch('utilities.validate_args.table_exists')
     def test_mlp_preprocessor_summary_invalid_raises_exception(self, mock1):
         tbl_exists_mock = Mock()
         tbl_exists_mock.side_effect = [False, True]
         self.subject.table_exists = tbl_exists_mock
-        with self.assertRaises(Exception):
-            self.subject.MLPPreProcessor("input")
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.MLPMinibatchPreProcessor("input")
         tbl_exists_mock.assert_any_call("input_summary")
 
 
@@ -71,8 +71,8 @@ class MLPMiniBatchTestCase(unittest.TestCase):
         tbl_exists_mock = Mock()
         tbl_exists_mock.side_effect = [True, False]
         self.subject.table_exists = tbl_exists_mock
-        with self.assertRaises(Exception):
-            self.subject.MLPPreProcessor("input")
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.MLPMinibatchPreProcessor("input")
         tbl_exists_mock.assert_any_call("input_standardization")
 
     @patch('utilities.validate_args.table_exists')
@@ -80,8 +80,8 @@ class MLPMiniBatchTestCase(unittest.TestCase):
         self.subject.table_exists = Mock()
         self.subject.input_tbl_valid = Mock()
         self.plpy_mock_execute.return_value = [{'key': 'value'}]
-        with self.assertRaises(Exception):
-            self.module = self.subject.MLPPreProcessor("input")
+        with self.assertRaises(plpy.PLPYException):
+            self.module = self.subject.MLPMinibatchPreProcessor("input")
 
 
     @patch('utilities.validate_args.table_exists')
@@ -92,8 +92,8 @@ class MLPMiniBatchTestCase(unittest.TestCase):
                                                 'dependent_varname': 'value',
                                                 'foo': 'bar'}]
 
-        with self.assertRaises(Exception):
-            self.subject.MLPPreProcessor("input")
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.MLPMinibatchPreProcessor("input")
 
     @patch('utilities.validate_args.table_exists')
     def test_mlp_preprocessor_indep_var_not_present_raises_exception(self, mock1):
@@ -103,8 +103,8 @@ class MLPMiniBatchTestCase(unittest.TestCase):
                                                 'dependent_varname': 'value',
                                                 'class_values': 'value'}]
 
-        with self.assertRaises(Exception):
-            self.subject.MLPPreProcessor("input")
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.MLPMinibatchPreProcessor("input")
 
     @patch('utilities.validate_args.table_exists')
     def test_mlp_preprocessor_dep_var_not_present_raises_exception(self, mock1):
@@ -114,8 +114,8 @@ class MLPMiniBatchTestCase(unittest.TestCase):
                                                 'foo': 'value',
                                                 'class_values': 'value'}]
 
-        with self.assertRaises(Exception):
-            self.module = self.subject.MLPPreProcessor("input")
+        with self.assertRaises(plpy.PLPYException):
+            self.module = self.subject.MLPMinibatchPreProcessor("input")
 
     @patch('utilities.validate_args.table_exists')
     def test_mlp_preprocessor_cols_present_returns_dict(self, mock1):
@@ -126,7 +126,7 @@ class MLPMiniBatchTestCase(unittest.TestCase):
                                                 'dependent_varname': 'value',
                                                 'class_values': 'regression',
                                                 'foo': 'bar'}]
-        self.module = self.subject.MLPPreProcessor("input")
+        self.module = self.subject.MLPMinibatchPreProcessor("input")
         self.assertTrue(self.module.preprocessed_summary_dict)
         self.assertEqual(4, len(self.module.preprocessed_summary_dict))
 
@@ -145,11 +145,11 @@ class MLPMiniBatchTestCase(unittest.TestCase):
         self.assertTrue(is_mb_enabled)
 
         self.plpy_mock_execute.return_value = [{'n_x': 1, 'n_y': 2, 'n_z': 4}]
-        with self.assertRaises(Exception):
+        with self.assertRaises(plpy.PLPYException):
             self.subject.check_if_minibatch_enabled('does not matter', 'still does not matter')
 
         self.plpy_mock_execute.return_value = [{'n_x': None, 'n_y': None, 'n_z': None}]
-        with self.assertRaises(Exception):
+        with self.assertRaises(plpy.PLPYException):
             self.subject.check_if_minibatch_enabled('does not matter', 'still does not matter')
 
 

http://git-wip-us.apache.org/repos/asf/madlib/blob/ab83c95b/src/ports/postgres/modules/utilities/test/unit_tests/plpy_mock.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/plpy_mock.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/plpy_mock.py_in
index 3058830..dd18649 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/plpy_mock.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/plpy_mock.py_in
@@ -22,7 +22,7 @@ def __init__(self):
     pass
 
 def error(message):
-    raise Exception(message)
+    raise PLPYException(message)
 
 def execute(query):
     pass
@@ -32,3 +32,12 @@ def warning(query):
 
 def info(query):
     print query
+
+
+class PLPYException(Exception):
+    def __init__(self, message):
+        super(PLPYException, self).__init__()
+        self.message = message
+
+    def __str__(self):
+        return repr(self.message)

http://git-wip-us.apache.org/repos/asf/madlib/blob/ab83c95b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
index 5f83e87..548a6dc 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_minibatch_preprocessing.py_in
@@ -101,7 +101,7 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
         self.assertEqual(2, self.plpy_mock_execute.call_count)
 
     def test_minibatch_preprocessor_multiple_dep_var_raises_exception(self):
-            with self.assertRaises(Exception):
+            with self.assertRaises(plpy.PLPYException):
                 self.module.MiniBatchPreProcessor(self.default_schema_madlib,
                                                                      self.default_source_table,
                                                                      self.default_output_table,
@@ -111,7 +111,7 @@ class MiniBatchPreProcessingTestCase(unittest.TestCase):
                                                                      self.default_buffer_size)
 
     def test_minibatch_preprocessor_buffer_size_zero_fails(self):
-        with self.assertRaises(Exception):
+        with self.assertRaises(plpy.PLPYException):
             self.module.MiniBatchPreProcessor(self.default_schema_madlib,
                                                              self.default_source_table,
                                                              self.default_output_table,
@@ -193,7 +193,7 @@ class MiniBatchQueryFormatterTestCase(unittest.TestCase):
                                             dep_var_array_str)
 
     def test_get_dep_var_array_str_other_type(self):
-        with self.assertRaises(Exception):
+        with self.assertRaises(plpy.PLPYException):
             self.subject.get_dep_var_array_and_classes(self.default_dep_var,
                                                        'other')
 

http://git-wip-us.apache.org/repos/asf/madlib/blob/ab83c95b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
index c40142d..0f38a05 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
@@ -92,7 +92,7 @@ class UtilitiesTestCase(unittest.TestCase):
         self.subject.input_tbl_valid = Mock()
         self.subject.output_tbl_valid = Mock()
         self.subject.is_var_valid = Mock(side_effect = [False, True, True])
-        with self.assertRaises(Exception) as context:
+        with self.assertRaises(plpy.PLPYException):
             self.subject.validate_module_input_params(self.default_source_table,
                                                       self.default_output_table,
                                                       "invalid_indep_var",
@@ -100,33 +100,25 @@ class UtilitiesTestCase(unittest.TestCase):
                                                       self.default_module,
                                                       None)
 
-        expected_exception = "unittest_module error: invalid independent_varname " \
-                             "('invalid_indep_var') for source_table (source)!"
-        self.assertEqual(expected_exception, context.exception.message)
-
     def test_validate_module_input_params_dep_var_invalid(self):
         self.subject.input_tbl_valid = Mock()
         self.subject.output_tbl_valid = Mock()
         self.subject.is_var_valid = Mock(side_effect = [True, False, True])
 
-        with self.assertRaises(Exception) as context:
+        with self.assertRaises(plpy.PLPYException):
             self.subject.validate_module_input_params(self.default_source_table,
                                                       self.default_output_table,
                                                       self.default_ind_var,
                                                       "invalid_dep_var",
                                                       self.default_module, None)
 
-        expected_exception = "unittest_module error: invalid dependent_varname " \
-                             "('invalid_dep_var') for source_table (source)!"
-        self.assertEqual(expected_exception, context.exception.message)
-
     def test_validate_module_input_params_grouping_cols_invalid(self):
         self.subject.input_tbl_valid = Mock()
         self.subject.output_tbl_valid = Mock()
         is_var_valid_mock = Mock()
         is_var_valid_mock.side_effect = [True, True, False]
         self.subject.is_var_valid = is_var_valid_mock
-        with self.assertRaises(Exception) as context:
+        with self.assertRaises(plpy.PLPYException):
             self.subject.validate_module_input_params(self.default_source_table,
                                                   self.default_output_table,
                                                   self.default_ind_var,
@@ -134,10 +126,6 @@ class UtilitiesTestCase(unittest.TestCase):
                                                   self.default_module,
                                                   'invalid_grp_col')
 
-        expected_exception = "unittest_module error: invalid grouping_cols " \
-                             "('invalid_grp_col') for source_table (source)!"
-        self.assertEqual(expected_exception, context.exception.message)
-
     def test_is_var_valid_all_nulls(self):
         self.assertEqual(False, self.subject.is_var_valid(None, None))
 

http://git-wip-us.apache.org/repos/asf/madlib/blob/ab83c95b/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
index 8c59256..f94ee03 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
@@ -47,80 +47,58 @@ class ValidateArgsTestCase(unittest.TestCase):
         self.module_patcher.stop()
 
     def test_input_tbl_valid_null_tbl_raises_exception(self):
-        with self.assertRaises(Exception) as context:
+        with self.assertRaises(plpy.PLPYException):
           self.subject.input_tbl_valid(None, "unittest_module")
         
-        self.assertEqual("unittest_module error: NULL/empty input"
-                          " table name!", context.exception.message)
-
     def test_input_tbl_valid_whitespaces_tbl_raises(self):
-        with self.assertRaises(Exception) as context:
+        with self.assertRaises(plpy.PLPYException):
           self.subject.input_tbl_valid("  ", "unittest_module")
         
-        self.assertEqual("unittest_module error: NULL/empty input"
-                          " table name!", context.exception.message)
-
     def test_input_tbl_valid_table_not_exists_raises(self):
-				self.subject.table_exists = Mock(return_value=False)
-				with self.assertRaises(Exception) as context:
-					self.subject.input_tbl_valid("foo", "unittest_module")
-        
-				self.assertEqual("unittest_module error: Input table"
-                          " 'foo' does not exist", context.exception.message)
+        self.subject.table_exists = Mock(return_value=False)
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.input_tbl_valid("foo", "unittest_module")
 
     def test_input_tbl_valid_table_tbl_empty_raises(self):
-				self.subject.table_exists = Mock(return_value=True)
-				self.subject.table_is_empty = Mock(return_value=True)
-				with self.assertRaises(Exception) as context:
-					self.subject.input_tbl_valid("foo", "unittest_module")
-        
-				self.assertEqual("unittest_module error: Input table"
-                          " 'foo' is empty!", context.exception.message)
+        self.subject.table_exists = Mock(return_value=True)
+        self.subject.table_is_empty = Mock(return_value=True)
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.input_tbl_valid("foo", "unittest_module")
 
     def test_input_tbl_valid_table_tbl_empty_passes(self):
-				self.subject.table_exists = Mock(return_value=True)
-				self.subject.table_is_empty = Mock(return_value=True)
-				self.subject.input_tbl_valid("foo", "unittest_module",
-																				check_empty = False)
+        self.subject.table_exists = Mock(return_value=True)
+        self.subject.table_is_empty = Mock(return_value=True)
+        self.subject.input_tbl_valid("foo", "unittest_module",
+                                     check_empty=False)
 
     def test_input_tbl_valid_table_passes(self):
-				self.subject.table_exists = Mock(return_value=True)
-				self.subject.table_is_empty = Mock(return_value=False)
-				self.subject.input_tbl_valid("foo", "unittest_module")
-        
+        self.subject.table_exists = Mock(return_value=True)
+        self.subject.table_is_empty = Mock(return_value=False)
+        self.subject.input_tbl_valid("foo", "unittest_module")
+
     def test_output_tbl_valid_null_tbl_raises_exception(self):
-        with self.assertRaises(Exception) as context:
-          self.subject.output_tbl_valid(None, "unittest_module")
-        
-        self.assertEqual("unittest_module error: NULL/empty output"
-                          " table name!", context.exception.message)
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.output_tbl_valid(None, "unittest_module")
+
 
     def test_output_tbl_valid_whitespaces_tbl_raises_exception(self):
-        with self.assertRaises(Exception) as context:
-          self.subject.output_tbl_valid("  ", "unittest_module")
-        
-        self.assertEqual("unittest_module error: NULL/empty output"
-                          " table name!", context.exception.message)
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.output_tbl_valid("  ", "unittest_module")
+
 
     def test_output_tbl_valid_null_raises(self):
-        with self.assertRaises(Exception) as context:
-          self.subject.output_tbl_valid(" null ", "unittest_module")
-        
-        self.assertEqual("unittest_module error: NULL/empty output"
-                          " table name!", context.exception.message)
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.output_tbl_valid(" null ", "unittest_module")
+
 
     def test_output_tbl_valid_tbl_exists_raises(self):
-				self.subject.table_exists = Mock(return_value=True)
-				with self.assertRaises(Exception) as context:
-					self.subject.output_tbl_valid(" foo ", "unittest_module")
-        
-				self.assertEqual("unittest_module error: Output table foo"
-												" already exists.\n Drop it before calling"
-												" the function.", context.exception.message)
+        self.subject.table_exists = Mock(return_value=True)
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.output_tbl_valid(" foo ", "unittest_module")
 
     def test_output_tbl_valid_table_passes(self):
-				self.subject.table_exists = Mock(return_value=False)
-				self.subject.output_tbl_valid("foo", "unittest_module")
+        self.subject.table_exists = Mock(return_value=False)
+        self.subject.output_tbl_valid("foo", "unittest_module")
 
 if __name__ == '__main__':
     unittest.main()