You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nk...@apache.org on 2021/02/09 20:38:17 UTC

[madlib] 02/04: DL: Fix validation in fit, fit multiple, evaluate and predict

This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch dl/fit-mult-null-table-rebase-in-progress
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 17073577229c185a7a2ab453776e1ac781374655
Author: Nikhil Kak <nk...@vmware.com>
AuthorDate: Fri Jan 22 16:43:01 2021 -0800

    DL: Fix validation in fit, fit multiple, evaluate and predict
    
    JIRA: MADLIB-1464
    
    Previously while calling fit/fit_multiple/evaluate/predict with invalid
    input and output tables (null or missing), we would print the wrong
    error message. This commit refactors the code so that we print the
    expected error message.
    
    Refactored the validator code such that we don't need to create the info
    and summary table names in the fit multiple class. Instead we do that in
    the validator and then the validator object can be used to get the table
    names. This makes it easier to validate all the tables inside the
    validator class. This commit also refactors the code so that we move
    all the validation code inside the validator class, except for the source
    table validation: the source table needs to be validated before we call the
    get_data_distribution_per_segment function, which has to be called before
    the validator constructor.
    
    To test this, we created a plpython function that asserts that the query
    failed with the expected error message. Added a couple of wrapper
    functions on top of this function that test for null input and output tables.
    
    Co-authored-by: Ekta Khanna <ek...@vmware.com>
---
 .../modules/deep_learning/madlib_keras.py_in       |  63 ++----
 .../madlib_keras_fit_multiple_model.py_in          |  82 +++-----
 .../deep_learning/madlib_keras_predict.py_in       |   3 +-
 .../deep_learning/madlib_keras_validator.py_in     | 222 ++++++++++-----------
 .../test/madlib_keras_evaluate.sql_in              |   9 +
 .../deep_learning/test/madlib_keras_fit.sql_in     |  42 ++++
 .../test/madlib_keras_model_selection.sql_in       |  37 ++++
 .../test/madlib_keras_multi_io.sql_in              |  25 +++
 .../deep_learning/test/madlib_keras_predict.sql_in |  20 ++
 .../test/madlib_keras_predict_byom.sql_in          |  27 +++
 .../test/unit_tests/test_madlib_keras.py_in        |  33 ++-
 .../postgres/modules/utilities/utilities.sql_in    |  26 +++
 12 files changed, 355 insertions(+), 234 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 49892b6..c4f8611 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -103,6 +103,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     fit_params = "" if not fit_params else fit_params
     _assert(compile_params, "Compile parameters cannot be empty or NULL.")
 
+    input_tbl_valid(source_table, module_name)
     segments_per_host = get_data_distribution_per_segment(source_table)
     use_gpus = use_gpus if use_gpus else False
     if use_gpus:
@@ -114,51 +115,27 @@ def fit(schema_madlib, source_table, model, model_arch_table,
 
     if object_table is not None:
         object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
-
-    source_summary_table = add_postfix(source_table, "_summary")
-    input_tbl_valid(source_summary_table, module_name)
-    src_summary_dict = get_source_summary_table_dict(source_summary_table)
-
-    columns_dict = {}
-    columns_dict['mb_dep_var_cols'] = src_summary_dict['dependent_varname']
-    columns_dict['mb_indep_var_cols'] = src_summary_dict['independent_varname']
-    columns_dict['dep_shape_cols'] = [add_postfix(i, "_shape") for i in columns_dict['mb_dep_var_cols']]
-    columns_dict['ind_shape_cols'] = [add_postfix(i, "_shape") for i in columns_dict['mb_indep_var_cols']]
-
-    multi_dep_count = len(columns_dict['mb_dep_var_cols'])
-    val_dep_var = None
-    val_ind_var = None
-
-    val_dep_shape_cols = None
-    val_ind_shape_cols = None
-    if validation_table:
-        validation_summary_table = add_postfix(validation_table, "_summary")
-        input_tbl_valid(validation_summary_table, module_name)
-        val_summary_dict = get_source_summary_table_dict(validation_summary_table)
-
-        val_dep_var = val_summary_dict['dependent_varname']
-        val_ind_var = val_summary_dict['independent_varname']
-        val_dep_shape_cols = [add_postfix(i, "_shape") for i in val_dep_var]
-        val_ind_shape_cols = [add_postfix(i, "_shape") for i in val_ind_var]
-
     fit_validator = FitInputValidator(
         source_table, validation_table, model, model_arch_table, model_id,
-        columns_dict['mb_dep_var_cols'], columns_dict['mb_indep_var_cols'],
-        columns_dict['dep_shape_cols'], columns_dict['ind_shape_cols'],
         num_iterations, metrics_compute_frequency, warm_start,
-        use_gpus, accessible_gpus_for_seg, object_table,
-        val_dep_var, val_ind_var)
-
-    columns_dict['val_dep_var'] = val_dep_var
-    columns_dict['val_ind_var'] = val_ind_var
-    columns_dict['val_dep_shape_cols'] = val_dep_shape_cols
-    columns_dict['val_ind_shape_cols'] = val_ind_shape_cols
-
-    fit_validator.dependent_varname = columns_dict['mb_dep_var_cols']
-    fit_validator.independent_varname = columns_dict['mb_indep_var_cols']
-    fit_validator.dep_shape_col = columns_dict['dep_shape_cols']
-    fit_validator.ind_shape_col = columns_dict['ind_shape_cols']
+        use_gpus, accessible_gpus_for_seg, object_table)
 
+    columns_dict = {}
+    columns_dict['mb_dep_var_cols'] = fit_validator.dependent_varname
+    columns_dict['mb_indep_var_cols'] = fit_validator.independent_varname
+    columns_dict['dep_shape_cols'] = fit_validator.dep_shape_cols
+    columns_dict['ind_shape_cols'] = fit_validator.ind_shape_cols
+    columns_dict['val_dep_var'] = fit_validator.val_dep_var
+    columns_dict['val_ind_var'] = fit_validator.val_ind_var
+    columns_dict['val_dep_shape_cols'] = fit_validator.val_dep_shape_cols
+    columns_dict['val_ind_shape_cols'] = fit_validator.val_ind_shape_cols
+    multi_dep_count = len(fit_validator.dependent_varname)
+
+    # fit_validator.dependent_varname = columns_dict['mb_dep_var_cols']
+    # fit_validator.independent_varname = columns_dict['mb_indep_var_cols']
+    # fit_validator.dep_shape_col = columns_dict['dep_shape_cols']
+    # fit_validator.ind_shape_col = columns_dict['ind_shape_cols']
+    src_summary_dict = fit_validator.src_summary_dict
     class_values_colnames = [add_postfix(i, "_class_values") for i in columns_dict['mb_dep_var_cols']]
     src_summary_dict['class_values_type'] =[ get_expr_type(
         i, fit_validator.source_summary_table) for i in class_values_colnames]
@@ -446,6 +423,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
                    normalizing_const_colname=NORMALIZING_CONST_COLNAME,
                    FLOAT32_SQL_TYPE = FLOAT32_SQL_TYPE,
                    model_id_colname = ModelArchSchema.MODEL_ID,
+                   source_summary_table=fit_validator.source_summary_table,
                    **locals()),
                    ["TEXT", "TEXT", "TEXT", "TEXT", "DOUBLE PRECISION[]"])
     plpy.execute(create_output_summary_table,
@@ -867,6 +845,7 @@ def evaluate(schema_madlib, model_table, test_table, output_table,
 
     module_name = 'madlib_keras_evaluate'
     is_mult_model = mst_key is not None
+    test_summary_table = None
     if test_table:
         test_summary_table = add_postfix(test_table, "_summary")
     model_summary_table = None
@@ -874,6 +853,7 @@ def evaluate(schema_madlib, model_table, test_table, output_table,
         model_summary_table = add_postfix(model_table, "_summary")
 
     mult_where_clause = ""
+    input_tbl_valid(model_table, module_name)
     if is_mult_model:
         mult_where_clause = "WHERE mst_key = {0}".format(mst_key)
         model_summary_table = create_summary_view(module_name, model_table, mst_key)
@@ -1035,7 +1015,6 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, columns_dict, compile_
         weights = '$1'
         mult_sql = ''
         custom_map_var = '$2'
-        plpy.info(eval_sql.format(**locals()))
         evaluate_query = plpy.prepare(eval_sql.format(**locals()), ["bytea", "bytea"])
         res = plpy.execute(evaluate_query, [serialized_weights, custom_function_map])
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index deda8f6..22b9401 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -45,6 +45,7 @@ from utilities.utilities import unique_string
 from utilities.utilities import madlib_version
 from utilities.utilities import is_platform_pg
 from utilities.utilities import get_seg_number
+from utilities.validate_args import input_tbl_valid
 import utilities.debug as DEBUG
 from utilities.debug import plpy_prepare
 from utilities.debug import plpy_execute
@@ -110,8 +111,6 @@ class FitMultipleModel(object):
         self.source_table = source_table
         self.validation_table = validation_table
         self.model_selection_table = model_selection_table
-        if self.model_selection_table:
-            self.model_selection_summary_table = add_postfix(self.model_selection_table, '_summary')
 
         self.dist_key_col = DISTRIBUTION_KEY_COLNAME
         self.prev_dist_key_col = '__prev_dist_key__'
@@ -134,40 +133,6 @@ class FitMultipleModel(object):
         self.train_mst_loss = defaultdict(list)
         self.train_mst_metric = defaultdict(list)
         self.info_str = ""
-        source_summary_table = add_postfix(self.source_table, "_summary")
-        input_tbl_valid(source_summary_table, self.module_name)
-        src_summary_dict = get_source_summary_table_dict(source_summary_table)
-
-        self.mb_dep_var_cols = src_summary_dict['dependent_varname']
-        self.mb_indep_var_cols = src_summary_dict['independent_varname']
-        self.dep_shape_cols = [add_postfix(i, "_shape") for i in self.mb_dep_var_cols]
-        self.ind_shape_cols = [add_postfix(i, "_shape") for i in self.mb_indep_var_cols]
-
-        self.columns_dict = {}
-        self.columns_dict['mb_dep_var_cols'] = self.mb_dep_var_cols
-        self.columns_dict['mb_indep_var_cols'] = self.mb_indep_var_cols
-        self.columns_dict['dep_shape_cols'] = self.dep_shape_cols
-        self.columns_dict['ind_shape_cols'] = self.ind_shape_cols
-
-        self.val_dep_var = None
-        self.val_ind_var = None
-        self.val_dep_shape_cols = None
-        self.val_ind_shape_cols = None
-        if validation_table:
-            validation_summary_table = add_postfix(self.validation_table, "_summary")
-            input_tbl_valid(validation_summary_table, self.module_name)
-            val_summary_dict = get_source_summary_table_dict(validation_summary_table)
-
-            self.val_dep_var = val_summary_dict['dependent_varname']
-            self.val_ind_var = val_summary_dict['independent_varname']
-            self.val_dep_shape_cols = [add_postfix(i, "_shape") for i in self.val_dep_var]
-            self.val_ind_shape_cols = [add_postfix(i, "_shape") for i in self.val_ind_var]
-
-        self.columns_dict['val_dep_var'] = self.val_dep_var
-        self.columns_dict['val_ind_var'] = self.val_ind_var
-        self.columns_dict['val_dep_shape_cols'] = self.val_dep_shape_cols
-        self.columns_dict['val_ind_shape_cols'] = self.val_ind_shape_cols
-
         self.use_gpus = use_gpus if use_gpus else False
         self.model_input_tbl = unique_string('model_input')
         self.model_output_tbl = unique_string('model_output')
@@ -178,6 +143,7 @@ class FitMultipleModel(object):
         self.rotate_schedule_tbl_plan = self.add_object_maps_plan = None
         self.hop_plan = self.udf_plan = None
 
+        input_tbl_valid(self.source_table, self.module_name)
         self.segments_per_host = get_data_distribution_per_segment(source_table)
         if self.use_gpus:
             self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
@@ -186,30 +152,32 @@ class FitMultipleModel(object):
             self.accessible_gpus_for_seg = get_seg_number()*[0]
 
         self.original_model_output_tbl = model_output_table
-        if not self.original_model_output_tbl:
-            plpy.error("Must specify an output table.")
-
-        self.model_info_tbl = add_postfix(
-            self.original_model_output_tbl, '_info')
-        self.model_summary_table = add_postfix(
-            self.original_model_output_tbl, '_summary')
-
         self.warm_start = bool(warm_start)
 
         self.fit_validator_train = FitMultipleInputValidator(
             self.source_table, self.validation_table, self.original_model_output_tbl,
-            self.model_selection_table, self.model_selection_summary_table,
-            self.mb_dep_var_cols, self.mb_indep_var_cols, self.dep_shape_cols,
-            self.ind_shape_cols, self.num_iterations,
-            self.model_info_tbl, self.mst_key_col, self.model_arch_table_col,
-            self.metrics_compute_frequency, self.warm_start, self.use_gpus,
-            self.accessible_gpus_for_seg, self.val_dep_var, self.val_ind_var)
+            self.model_selection_table, self.num_iterations, self.mst_key_col,
+            self.model_arch_table_col, self.metrics_compute_frequency,
+            self.warm_start, self.use_gpus, self.accessible_gpus_for_seg)
+        self.model_info_tbl = self.fit_validator_train.output_model_info_table
+        self.model_summary_table = self.fit_validator_train.output_summary_model_table
+        self.model_selection_summary_table = self.fit_validator_train.model_selection_summary_table
         if self.metrics_compute_frequency is None:
             self.metrics_compute_frequency = num_iterations
 
         self.msts = self.fit_validator_train.msts
         self.model_arch_table = self.fit_validator_train.model_arch_table
         self.object_table = self.fit_validator_train.object_table
+        self.columns_dict = {}
+        self.columns_dict['mb_dep_var_cols'] = self.fit_validator_train.dependent_varname
+        self.columns_dict['mb_indep_var_cols'] = self.fit_validator_train.independent_varname
+        self.columns_dict['dep_shape_cols'] = self.fit_validator_train.dep_shape_cols
+        self.columns_dict['ind_shape_cols'] = self.fit_validator_train.ind_shape_cols
+        self.columns_dict['val_dep_var'] = self.fit_validator_train.val_dep_var
+        self.columns_dict['val_ind_var'] = self.fit_validator_train.val_ind_var
+        self.columns_dict['val_dep_shape_cols'] = self.fit_validator_train.val_dep_shape_cols
+        self.columns_dict['val_ind_shape_cols'] = self.fit_validator_train.val_ind_shape_cols
+
         self.metrics_iters = []
         self.object_map_col = 'object_map'
         self.custom_mst_keys = None
@@ -222,7 +190,7 @@ class FitMultipleModel(object):
 
         self.dist_key_mapping, self.images_per_seg_train = \
             get_image_count_per_seg_for_minibatched_data_from_db(
-                self.source_table, self.dep_shape_cols[0])
+                self.source_table, self.fit_validator_train.dep_shape_cols[0])
 
         if self.validation_table:
             self.valid_mst_metric_eval_time = defaultdict(list)
@@ -230,7 +198,7 @@ class FitMultipleModel(object):
             self.valid_mst_metric = defaultdict(list)
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
-                    self.validation_table, self.val_dep_shape_cols[0])
+                    self.validation_table, self.fit_validator_train.val_dep_shape_cols[0])
 
         self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
         self.max_dist_key = sorted(self.dist_keys)[-1]
@@ -713,7 +681,7 @@ class FitMultipleModel(object):
         source_summary_table = self.fit_validator_train.source_summary_table
         src_summary_dict = get_source_summary_table_dict(source_summary_table)
 
-        class_values_colnames = [add_postfix(i, "_class_values") for i in self.mb_dep_var_cols]
+        class_values_colnames = [add_postfix(i, "_class_values") for i in self.fit_validator_train.dependent_varname]
         # class_values = src_summary_dict['class_values']
         class_values_type =[get_expr_type(i, source_summary_table) for i in class_values_colnames]
         # class_values_type = src_summary_dict['class_values_type']
@@ -897,10 +865,10 @@ class FitMultipleModel(object):
             """.format(self=self))
 
         #TODO: Fix these to add multi io
-        dep_shape_col = self.dep_shape_cols[0]
-        ind_shape_col = self.ind_shape_cols[0]
-        dep_var_col = self.mb_dep_var_cols[0]
-        indep_var_col = self.mb_indep_var_cols[0]
+        dep_shape_col = self.fit_validator_train.dep_shape_cols[0]
+        ind_shape_col = self.fit_validator_train.ind_shape_cols[0]
+        dep_var_col = self.fit_validator_train.dependent_varname[0]
+        indep_var_col = self.fit_validator_train.independent_varname[0]
         source_table = self.source_table
 
         if self.use_caching:
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 053a5f9..0e5b1b9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -55,6 +55,8 @@ class BasePredict():
         self.module_name = module_name
 
         self.use_gpus = use_gpus if use_gpus else False
+        input_tbl_valid(test_table, module_name)
+        input_tbl_valid(table_to_validate, module_name)
         self.segments_per_host = get_data_distribution_per_segment(test_table)
         if self.use_gpus:
             accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib,
@@ -252,7 +254,6 @@ class Predict(BasePredict):
             plpy.execute("DROP VIEW IF EXISTS {0}".format(self.temp_summary_view))
 
     def validate(self):
-        input_tbl_valid(self.model_table, self.module_name)
         if self.is_mult_model and not columns_exist_in_table(self.model_table, ['mst_key']):
             plpy.error("{self.module_name}: Single model should not pass mst_key".format(**locals()))
         if not self.is_mult_model and columns_exist_in_table(self.model_table, ['mst_key']):
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 2549b84..21eff15 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -260,20 +260,12 @@ class InputValidator:
 
 class FitCommonValidator(object):
     def __init__(self, source_table, validation_table, output_model_table,
-                 model_arch_table, model_id, dependent_varname,
-                 independent_varname, dep_shape_cols, ind_shape_cols, num_iterations,
-                 metrics_compute_frequency, warm_start,
-                 use_gpus, accessible_gpus_for_seg, module_name, object_table,
-                 val_dep_var, val_ind_var):
+                 num_iterations, metrics_compute_frequency, warm_start,
+                 use_gpus, accessible_gpus_for_seg, module_name, object_table):
         self.source_table = source_table
         self.validation_table = validation_table
         self.output_model_table = output_model_table
-        self.model_arch_table = model_arch_table
-        self.model_id = model_id
-        self.dependent_varname = dependent_varname
-        self.independent_varname = independent_varname
-        self.dep_shape_cols = dep_shape_cols
-        self.ind_shape_cols = ind_shape_cols
+
         self.metrics_compute_frequency = metrics_compute_frequency
         self.warm_start = warm_start
         self.num_iterations = num_iterations
@@ -282,42 +274,52 @@ class FitCommonValidator(object):
         if self.source_table:
             self.source_summary_table = add_postfix(
                 self.source_table, "_summary")
+        if self.validation_table:
+            self.validation_summary_table = add_postfix(
+                self.validation_table, "_summary")
         if self.output_model_table:
             self.output_summary_model_table = add_postfix(
                 self.output_model_table, "_summary")
         self.accessible_gpus_for_seg = accessible_gpus_for_seg
         self.module_name = module_name
-        self.val_dep_var = val_dep_var
-        self.val_ind_var = val_ind_var
 
-        self._validate_common_args()
+        self._validate_tables()
+
+        self.src_summary_dict = self.get_source_summary_table_dict(self.source_summary_table)
+
+        self.dependent_varname = self.src_summary_dict['dependent_varname']
+        self.independent_varname = self.src_summary_dict['independent_varname']
+        self.dep_shape_cols = [add_postfix(i, "_shape") for i in self.dependent_varname]
+        self.ind_shape_cols = [add_postfix(i, "_shape") for i in self.independent_varname]
+
+        self.val_dep_var = None
+        self.val_ind_var = None
+        self.val_dep_shape_cols = None
+        self.val_ind_shape_cols = None
+        if self.validation_table:
+            val_summary_dict = self.get_source_summary_table_dict(self.validation_summary_table)
+
+            self.val_dep_var = val_summary_dict['dependent_varname']
+            self.val_ind_var = val_summary_dict['independent_varname']
+            self.val_dep_shape_cols = [add_postfix(i, "_shape") for i in self.val_dep_var]
+            self.val_ind_shape_cols = [add_postfix(i, "_shape") for i in self.val_ind_var]
+
+        self._validate_tables_schema()
         if use_gpus:
             InputValidator._validate_gpu_config(self.module_name,
                 self.source_table, self.accessible_gpus_for_seg)
 
-    def _validate_common_args(self):
-        _assert(self.num_iterations > 0,
-            "{0}: Number of iterations cannot be < 1.".format(self.module_name))
-        _assert(self._is_valid_metrics_compute_frequency(),
-            "{0}: metrics_compute_frequency must be in the range (1 - {1}).".format(
-                self.module_name, self.num_iterations))
+    def _validate_tables(self):
         input_tbl_valid(self.source_table, self.module_name)
+        input_tbl_valid(self.source_summary_table, self.module_name)
+        if self.validation_table:
+            input_tbl_valid(self.validation_table, self.module_name)
+            input_tbl_valid(self.validation_summary_table, self.module_name)
+
         if self.object_table is not None:
             input_tbl_valid(self.object_table, self.module_name)
             cols_in_tbl_valid(self.object_table, CustomFunctionSchema.col_names, self.module_name)
 
-        cols_in_tbl_valid(self.source_summary_table,
-            [NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
-            'dependent_varname', 'independent_varname'], self.module_name)
-        if not is_platform_pg():
-            cols_in_tbl_valid(self.source_table, [DISTRIBUTION_KEY_COLNAME], self.module_name)
-
-        # Source table and validation tables must have the same schema
-        self._validate_input_table(self.source_table)
-        for i in self.dependent_varname:
-            validate_bytea_var_for_minibatch(self.source_table, i)
-
-        self._validate_validation_table()
         if self.warm_start:
             input_tbl_valid(self.output_model_table, self.module_name)
             input_tbl_valid(self.output_summary_model_table, self.module_name)
@@ -325,60 +327,59 @@ class FitCommonValidator(object):
             output_tbl_valid(self.output_model_table, self.module_name)
             output_tbl_valid(self.output_summary_model_table, self.module_name)
 
-    def _validate_input_table(self, table, is_validation_table=False):
-
-        independent_varname = self.val_ind_var if is_validation_table else self.independent_varname
-        dependent_varname = self.val_dep_var if is_validation_table else self.dependent_varname
-
-        for name in independent_varname:
-            _assert(is_var_valid(table, name),
-                "{module_name}: invalid independent_varname "
-                "('{independent_varname}') for table ({table}). "
-                "Please ensure that the input table ({table}) "
-                "has been preprocessed by the image preprocessor.".format(
-                    module_name=self.module_name,
-                    independent_varname=name,
-                    table=table))
-
-        for name in dependent_varname:
-            _assert(is_var_valid(table, name),
-                "{module_name}: invalid dependent_varname "
-                "('{dependent_varname}') for table ({table}). "
-                "Please ensure that the input table ({table}) "
-                "has been preprocessed by the image preprocessor.".format(
-                    module_name=self.module_name,
-                    dependent_varname=name,
-                    table=table))
-        if not is_validation_table:
-            for name in self.ind_shape_cols:
-                _assert(is_var_valid(table, name),
-                    "{module_name}: invalid independent_var_shape "
-                    "('{ind_shape_col}') for table ({table}). "
-                    "Please ensure that the input table ({table}) "
-                    "has been preprocessed by the image preprocessor.".format(
-                        module_name=self.module_name,
-                        ind_shape_col=name,
-                        table=table))
-
-            for name in self.dep_shape_cols:
-                _assert(is_var_valid(table, name),
-                    "{module_name}: invalid dependent_var_shape "
-                    "('{dep_shape_col}') for table ({table}). "
-                    "Please ensure that the input table ({table}) "
-                    "has been preprocessed by the image preprocessor.".format(
-                        module_name=self.module_name,
-                        dep_shape_col=name,
-                        table=table))
 
+    def _validate_tables_schema(self):
+        # Source table and validation tables must have the same schema
+        additional_cols = []
         if not is_platform_pg():
-            _assert(is_var_valid(table, DISTRIBUTION_KEY_COLNAME),
-                    "{module_name}: missing distribution key "
-                    "('{dist_key_col}') for table ({table}). "
-                    "Please ensure that the input table ({table}) "
-                    "has been preprocessed by the image preprocessor.".format(
+            additional_cols.append(DISTRIBUTION_KEY_COLNAME)
+
+        self._validate_columns_in_preprocessed_table(self.source_table,
+                                                    self.independent_varname +
+                                                    self.dependent_varname +
+                                                    self.ind_shape_cols +
+                                                    self.dep_shape_cols +
+                                                    additional_cols)
+        for i in self.dependent_varname:
+            validate_bytea_var_for_minibatch(self.source_table, i)
+
+        if self.validation_table and self.validation_table.strip() != '':
+            self._validate_columns_in_preprocessed_table(self.validation_table,
+                                                        self.val_ind_var +
+                                                        self.val_dep_var +
+                                                        self.val_ind_shape_cols +
+                                                        self.val_dep_shape_cols+
+                                                        additional_cols)
+            for i in self.val_dep_var:
+                validate_bytea_var_for_minibatch(self.validation_table, i)
+
+        cols_in_tbl_valid(self.source_summary_table,
+                          [NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
+                           'dependent_varname', 'independent_varname'], self.module_name)
+
+    def _validate_misc_args(self):
+        _assert(self.num_iterations > 0,
+                "{0}: Number of iterations cannot be < 1.".format(self.module_name))
+        _assert(self._is_valid_metrics_compute_frequency(),
+                "{0}: metrics_compute_frequency must be in the range (1 - {1}).".format(
+                    self.module_name, self.num_iterations))
+
+    def get_source_summary_table_dict(self, source_summary_table):
+        source_summary = plpy.execute("""
+                SELECT *
+                FROM {0}
+            """.format(source_summary_table))[0]
+        return source_summary
+
+    def _validate_columns_in_preprocessed_table(self, table_name, col_names):
+        for col in col_names:
+            _assert(is_var_valid(table_name, col),
+                    "{module_name}: invalid column name "
+                    "('{col}') for table ({table_name}). "
+                    "Please ensure that the input table ({table_name}) "
+                    "has been preprocessed.".format(
                         module_name=self.module_name,
-                        dist_key_col=DISTRIBUTION_KEY_COLNAME,
-                        table=table))
+                        **locals()))
 
     def _is_valid_metrics_compute_frequency(self):
         return self.metrics_compute_frequency is None or \
@@ -389,6 +390,8 @@ class FitCommonValidator(object):
         if self.validation_table and self.validation_table.strip() != '':
             input_tbl_valid(self.validation_table, self.module_name)
             self._validate_input_table(self.validation_table, True)
+            validation_summary_table = add_postfix(self.validation_table, "_summary")
+            input_tbl_valid(validation_summary_table, self.module_name)
             for i in self.val_dep_var:
                 dependent_vartype = get_expr_type(i,
                                                   self.validation_table)
@@ -403,71 +406,53 @@ class FitCommonValidator(object):
                                input_shape, 2, True)
         if self.validation_table:
             InputValidator.validate_input_shape(
-                self.validation_table, self.independent_varname,
+                self.validation_table,  self.independent_varname,
                 input_shape, 2, True)
 
 
 class FitInputValidator(FitCommonValidator):
     def __init__(self, source_table, validation_table, output_model_table,
-                 model_arch_table, model_id, dependent_varname,
-                 independent_varname, dep_shape_cols, ind_shape_cols, num_iterations,
+                 model_arch_table, model_id, num_iterations,
                  metrics_compute_frequency, warm_start,
-                 use_gpus, accessible_gpus_for_seg, object_table, val_dep_var, val_ind_var):
+                 use_gpus, accessible_gpus_for_seg, object_table):
 
         self.module_name = 'madlib_keras_fit'
         super(FitInputValidator, self).__init__(source_table,
                                                 validation_table,
                                                 output_model_table,
-                                                model_arch_table,
-                                                model_id,
-                                                dependent_varname,
-                                                independent_varname,
-                                                dep_shape_cols,
-                                                ind_shape_cols,
                                                 num_iterations,
                                                 metrics_compute_frequency,
                                                 warm_start,
                                                 use_gpus,
                                                 accessible_gpus_for_seg,
                                                 self.module_name,
-                                                object_table,
-                                                val_dep_var,
-                                                val_ind_var)
-        InputValidator.validate_model_arch_table(self.module_name, self.model_arch_table,
-            self.model_id)
+                                                object_table
+                                                )
+        InputValidator.validate_model_arch_table(self.module_name, model_arch_table,
+            model_id)
 
 class FitMultipleInputValidator(FitCommonValidator):
     def __init__(self, source_table, validation_table, output_model_table,
-                 model_selection_table, model_selection_summary_table, dependent_varname,
-                 independent_varname, dep_shape_cols, ind_shape_cols,
-                 num_iterations, model_info_table, mst_key_col,
+                 model_selection_table, num_iterations, mst_key_col,
                  model_arch_table_col, metrics_compute_frequency, warm_start,
-                 use_gpus, accessible_gpus_for_seg, val_dep_var, val_ind_var):
+                 use_gpus, accessible_gpus_for_seg):
 
         self.module_name = 'madlib_keras_fit_multiple'
-
         input_tbl_valid(model_selection_table, self.module_name)
-        input_tbl_valid(model_selection_summary_table, self.module_name,
+        self.model_selection_summary_table = add_postfix(model_selection_table,
+                                                         '_summary')
+        input_tbl_valid(self.model_selection_summary_table, self.module_name,
                         error_suffix_str="Please ensure that the model selection table ({0}) "
                                          "has been created by "
                                          "load_model_selection_table().".format(
                                             model_selection_table))
         self.msts, self.model_arch_table, self.object_table = query_model_configs(
-            model_selection_table, model_selection_summary_table,
+            model_selection_table, self.model_selection_summary_table,
             mst_key_col, model_arch_table_col)
-        if warm_start:
-            input_tbl_valid(model_info_table, self.module_name)
-        else:
-            output_tbl_valid(model_info_table, self.module_name)
+        input_tbl_valid(self.model_arch_table, self.module_name)
         super(FitMultipleInputValidator, self).__init__(source_table,
                                                         validation_table,
                                                         output_model_table,
-                                                        self.model_arch_table,
-                                                        None,
-                                                        dependent_varname,
-                                                        independent_varname,
-                                                        dep_shape_cols,
-                                                        ind_shape_cols,
                                                         num_iterations,
                                                         metrics_compute_frequency,
                                                         warm_start,
@@ -477,6 +462,13 @@ class FitMultipleInputValidator(FitCommonValidator):
                                                         self.object_table,
                                                         val_dep_var,
                                                         val_ind_var)
+        self.output_model_info_table = add_postfix(output_model_table,
+                                                   '_info')
+
+        if warm_start:
+            input_tbl_valid(self.output_model_info_table, self.module_name)
+        else:
+            output_tbl_valid(self.output_model_info_table, self.module_name)
 
 class MstLoaderInputValidator():
     def __init__(self,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in
index cdda44a..5eed811 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in
@@ -61,6 +61,15 @@ SELECT assert(trap_error($TRAP$
     SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', FALSE ,1);
     $TRAP$) = 1, 'Should error out if mst_key is given for non-multi model tables');
 
+DROP TABLE IF EXISTS evaluate_out;
+SELECT assert(test_input_table($test$SELECT madlib_keras_evaluate(
+    NULL, 'cifar_10_sample_val', 'evaluate_out', FALSE)$test$),
+    'Failed to assert the correct error message for null model table'); -- 1st argument is the model table
+
+SELECT assert(test_input_table($test$SELECT madlib_keras_evaluate(
+    'keras_saved_out', NULL, 'evaluate_out', FALSE)$test$),
+    'Failed to assert the correct error message for null test table'); -- 2nd argument is the test table
+
 -- Test that evaluate errors out correctly if model_arch field missing from fit output
 DROP TABLE IF EXISTS evaluate_out;
 ALTER TABLE keras_saved_out DROP COLUMN model_arch;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index 988d1f3..eaa6916 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -30,6 +30,48 @@
 )
 
 m4_include(`SQLCommon.m4')
+SELECT assert(test_output_table($test$SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    NULL,
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3)$test$), 'Failed to assert the correct error message for null output table');
+
+SELECT assert(test_input_table($test$SELECT madlib_keras_fit(
+    NULL,
+    'keras_saved_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    NULL,
+    'cifar_10_sample_val')$test$), 'Failed to assert the correct error message for null source table');
+
+SELECT assert(test_input_table($test$SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    NULL,
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    NULL,
+    'cifar_10_sample_val')$test$), 'Failed to assert the correct error message for null model arch table');
+
+SELECT assert(test_error_msg($test$SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    NULL,
+    'table_does_not_exist')$test$, $test$'table_does_not_exist' does not exist$test$
+    ), 'Failed to assert the correct error message for non existing validation table');
 
 -- Please do not break up the compile_params string
 -- It might break the assertion
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 81554d3..2946184 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -342,6 +342,43 @@ SELECT load_model_selection_table(
         $$batch_size=32, epochs=1$$
     ]
 );
+----------- NULL input and output table validation
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT assert(test_input_table($test$SELECT madlib_keras_fit_multiple_model(
+	NULL,
+	'iris_multiple_model',
+	'mst_table_4row',
+	1,
+	FALSE
+);$test$), 'Failed to assert the correct error message for null source table');
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT assert(test_output_table($test$SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	NULL,
+	'mst_table_4row',
+	1,
+	FALSE
+);$test$), 'Failed to assert the correct error message for null output table');
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT assert(test_input_table($test$SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	NULL,
+	1,
+	FALSE
+);$test$), 'Failed to assert the correct error message for null mst table');
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT assert(test_error_msg($test$SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table_4row',
+	1,
+    FALSE,
+	'table_does_not_exist'
+);$test$, $test$'table_does_not_exist' does not exist$test$), 'Failed to assert the correct error message for non existing validation table');
 
 -- Test for one-hot encoded input data
 CREATE OR REPLACE FUNCTION test_fit_multiple_one_hot_encoded_input(caching boolean)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_multi_io.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_multi_io.sql_in
index 4afc47d..0c00851 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_multi_io.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_multi_io.sql_in
@@ -119,3 +119,28 @@ SELECT madlib_keras_fit(
     'test_custom_function_table'
 );
 
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+-- Multiple models test
+DROP TABLE IF EXISTS mst_table_1row, mst_table_1row_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_table_1row',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=16, epochs=1$$
+    ]
+);
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT assert(test_error_msg($test$SELECT madlib_keras_fit_multiple_model(
+	'iris_mult_packed',
+	'iris_model',
+	'mst_table_1row',
+	3,
+	FALSE)$test$, 'Multiple dependent and independent variables not supported'),
+	'Failed to assert the correct error message for multi-io not supported');
+!>)
+
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
index 82db074..9994739 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
@@ -66,6 +66,26 @@ SELECT assert(class_value IN ('0','1'),
     'Predicted value not in set of defined class values for model')
 FROM cifar10_predict;
 
+-- Test for null source table and null output table
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT assert(test_input_table($test$SELECT madlib_keras_predict(
+    NULL,
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict',
+    NULL,
+    FALSE)$test$), 'Failed to assert the correct error message for null model table');
+
+SELECT assert(test_input_table($test$SELECT madlib_keras_predict(
+    'keras_saved_out',
+    NULL,
+    'id',
+    'x',
+    'cifar10_predict',
+    NULL,
+    FALSE)$test$), 'Failed to assert the correct error message for null test table');
+
 DROP TABLE IF EXISTS cifar10_predict;
 SELECT assert(trap_error($TRAP$SELECT madlib_keras_predict(
     'keras_saved_out',
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in
index 5fcee51..6f258cd 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in
@@ -67,6 +67,33 @@ SELECT assert(
 FROM iris_predict AS p0,  iris_predict_byom AS p1
 WHERE p0.id=p1.id;
 
+DROP TABLE IF EXISTS iris_predict_byom; -- output table named in the two calls below
+SELECT assert(test_input_table($test$SELECT madlib_keras_predict_byom(
+     'iris_model_arch',
+     2,
+     NULL,
+     'id',
+     'attributes',
+     'iris_predict_byom',
+     'response',
+     NULL,
+     ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor',
+      'Iris-virginica']::text[]]
+     )$test$), 'Failed to assert the correct error message for null test table');
+
+SELECT assert(test_input_table($test$SELECT madlib_keras_predict_byom(
+     NULL,
+     2,
+     'iris_test',
+     'id',
+     'attributes',
+     'iris_predict_byom',
+     'response',
+     NULL,
+     ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor',
+      'Iris-virginica']::text[]]
+     )$test$), 'Failed to assert the correct error message for null model table');
+
 -- class_values NULL, pred_type is NULL (response)
 DROP TABLE IF EXISTS iris_predict_byom;
 SELECT madlib_keras_predict_byom(
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index bb40fba..928b753 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -867,6 +867,7 @@ class MadlibKerasPredictBYOMTestCase(unittest.TestCase):
         self.module.InputValidator.validate_predict_byom_tables = Mock()
         self.module.InputValidator.validate_input_shape = Mock()
         self.module.BasePredict.call_internal_keras = Mock()
+        self.module.input_tbl_valid = Mock()
 
     def tearDown(self):
         self.module_patcher.stop()
@@ -1278,9 +1279,11 @@ class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
         self.module_patcher.start()
         import madlib_keras_validator
         self.subject = madlib_keras_validator
-        self.subject.FitCommonValidator._validate_common_args = Mock()
-        self.dep_shape_cols = [[10,1,1,1]]
-        self.ind_shape_cols = [[10,2]]
+        self.subject.FitCommonValidator._validate_tables = Mock()
+        self.subject.FitCommonValidator.get_source_summary_table_dict = \
+            Mock(return_value={'dependent_varname':['a'],
+                               'independent_varname':['b']})
+        self.subject.FitCommonValidator._validate_tables_schema = Mock()
 
     def tearDown(self):
         self.module_patcher.stop()
@@ -1288,34 +1291,26 @@ class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
 
     def test_is_valid_metrics_compute_frequency_True_None(self):
         obj = self.subject.FitCommonValidator(
-            'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', self.dep_shape_cols,
-            self.ind_shape_cols, 5, None, False, False, [0],
-            'module_name', None, None, None)
+            'test_table', 'val_table', 'model_table', 5, None, False, False, [0],
+            'module_name', None)
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_True_num(self):
         obj = self.subject.FitCommonValidator(
-            'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', self.dep_shape_cols,
-            self.ind_shape_cols, 5, 3, False, False, [0],
-            'module_name', None, None, None)
+            'test_table', 'val_table', 'model_table', 5, 3, False, False, [0],
+            'module_name', None)
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_zero(self):
         obj = self.subject.FitCommonValidator(
-            'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', self.dep_shape_cols,
-            self.ind_shape_cols, 5, 0, False, False, [0],
-            'module_name', None, None, None)
+            'test_table', 'val_table', 'model_table', 5, 0, False, False, [0],
+            'module_name', None)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_greater(self):
         obj = self.subject.FitCommonValidator(
-            'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', self.dep_shape_cols,
-            self.ind_shape_cols, 5, 6, False, False, [0],
-            'module_name', None, None, None)
+            'test_table', 'val_table', 'model_table', 5, 6, False, False, [0],
+            'module_name', None)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
 
diff --git a/src/ports/postgres/modules/utilities/utilities.sql_in b/src/ports/postgres/modules/utilities/utilities.sql_in
index bbf861d..23abb40 100644
--- a/src/ports/postgres/modules/utilities/utilities.sql_in
+++ b/src/ports/postgres/modules/utilities/utilities.sql_in
@@ -542,6 +542,32 @@ BEGIN
 END;
 $$ LANGUAGE plpgsql;
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_error_msg(
+  stmt TEXT,  -- SQL statement that is expected to fail
+  msg  TEXT   -- substring that must appear in the raised error message
+)
+RETURNS BOOLEAN AS $$
+try:
+    plpy.execute(stmt)
+    return False  # stmt succeeded, so the expected error was never raised
+except Exception as ex:
+    return msg in ex.message  # True only if the error text contains msg
+$$ LANGUAGE plpythonu;
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_input_table(
+  stmt TEXT
+)
+RETURNS BOOLEAN AS $$
+SELECT MADLIB_SCHEMA.test_error_msg($1, 'NULL/empty input table name');
+$$ LANGUAGE SQL;
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_output_table(
+  stmt TEXT
+)
+RETURNS BOOLEAN AS $$
+SELECT MADLIB_SCHEMA.test_error_msg($1, 'NULL/empty output table name');
+$$ LANGUAGE SQL;
+
 -- A few of the gucs like plan_cache_mode and dev_opt_unsafe_truncate_in_subtransaction
 -- are only available in either > pg 11 or > gpdb 6.5. Using this function we
 -- can make sure to run the guc assertion test (assert_guc_value) on the correct