You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ok...@apache.org on 2020/10/22 15:04:32 UTC

[madlib] branch master updated: DL: Restrict access to the custom functions table

This is an automated email from the ASF dual-hosted git repository.

okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 0729a6f  DL: Restrict access to the custom functions table
0729a6f is described below

commit 0729a6f4d52855d08b8e882b65647152ea18fc96
Author: Orhan Kislal <ok...@apache.org>
AuthorDate: Thu Oct 22 13:50:21 2020 +0300

    DL: Restrict access to the custom functions table
    
    This commit ensures that any custom function table is created in
    the madlib schema. Only admins should be able to modify this table
    (other users are granted read-only SELECT access), which means that
    load_custom_function is available only to admins and superusers.
    
    Closes #520
---
 .../modules/deep_learning/madlib_keras.py_in       |   5 +-
 .../deep_learning/madlib_keras_automl.py_in        |  23 +++--
 .../madlib_keras_custom_function.py_in             | 107 ++++++++++++---------
 .../madlib_keras_custom_function.sql_in            |   4 +-
 .../madlib_keras_model_selection.py_in             |  16 +++
 .../deep_learning/madlib_keras_validator.py_in     |   2 +
 .../test/madlib_keras_custom_function.sql_in       |  26 +++++
 .../test/madlib_keras_model_averaging_e2e.sql_in   |   2 -
 .../test/madlib_keras_model_selection.sql_in       |   5 -
 .../test_madlib_keras_model_selection_table.py_in  |  39 +++++++-
 .../postgres/modules/utilities/utilities.py_in     |  19 ++++
 11 files changed, 186 insertions(+), 62 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 0d55028..d54deeb 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -43,6 +43,7 @@ from utilities.utilities import get_seg_number
 from utilities.utilities import madlib_version
 from utilities.utilities import unique_string
 from utilities.validate_args import get_expr_type
+from utilities.validate_args import quote_ident
 from utilities.control import MinWarning
 import tensorflow as tf
 
@@ -103,6 +104,9 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     else:
         accessible_gpus_for_seg = get_seg_number()*[0]
 
+    if object_table is not None:
+        object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
+
     fit_validator = FitInputValidator(
         source_table, validation_table, model, model_arch_table,
         model_id, mb_dep_var_col, mb_indep_var_col,
@@ -137,7 +141,6 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     # Compute total images on each segment
     dist_key_mapping, images_per_seg_train = get_image_count_per_seg_for_minibatched_data_from_db(source_table)
 
-
     if validation_table:
         seg_ids_val, images_per_seg_val = get_image_count_per_seg_for_minibatched_data_from_db(validation_table)
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
index d57a762..0f71fdf 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
@@ -30,6 +30,7 @@ from madlib_keras_fit_multiple_model import FitMultipleModel
 from madlib_keras_model_selection import MstSearch, ModelSelectionSchema
 from keras_model_arch_table import ModelArchSchema
 from utilities.validate_args import table_exists, drop_tables
+from utilities.validate_args import quote_ident
 
 
 class AutoMLSchema:
@@ -183,7 +184,11 @@ class KerasAutoML():
         self.compile_params_grid = compile_params_grid
         self.fit_params_grid = fit_params_grid
 
+        if object_table is not None:
+            object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
+
         MstLoaderInputValidator(
+            schema_madlib=self.schema_madlib,
             model_arch_table=self.model_arch_table,
             model_selection_table=self.model_selection_table,
             model_selection_summary_table=self.model_selection_summary_table,
@@ -303,9 +308,15 @@ class KerasAutoML():
             r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
             initial_vals[s] = (n, int(round(r)))
         self.start_training_time = self.get_current_timestamp()
-        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
-                                  self.compile_params_grid, self.fit_params_grid, 'random',
-                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+        random_search = MstSearch(self.schema_madlib,
+                                  self.model_arch_table,
+                                  self.model_selection_table,
+                                  self.model_id_list,
+                                  self.compile_params_grid,
+                                  self.fit_params_grid,
+                                  'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]),
+                                  self.random_state,
                                   self.object_table)
         random_search.load() # for populating mst tables
 
@@ -535,7 +546,7 @@ class KerasAutoML():
                     $MAD${self.random_state}$MAD$::TEXT AS random_state,
                     $MAD${self.object_table}$MAD$::TEXT AS object_table,
                     {self.use_gpus} AS use_gpus,
-                    (SELECT metrics_compute_frequency FROM {model_training.model_summary_table})::INTEGER 
+                    (SELECT metrics_compute_frequency FROM {model_training.model_summary_table})::INTEGER
                     AS metrics_compute_frequency,
                     $MAD${self.name}$MAD$::TEXT AS name,
                     $MAD${self.description}$MAD$::TEXT AS description,
@@ -544,9 +555,9 @@ class KerasAutoML():
                     (SELECT madlib_version FROM {model_training.model_summary_table}) AS madlib_version,
                     (SELECT num_classes FROM {model_training.model_summary_table})::INTEGER AS num_classes,
                     (SELECT class_values FROM {model_training.model_summary_table}) AS class_values,
-                    (SELECT dependent_vartype FROM {model_training.model_summary_table}) 
+                    (SELECT dependent_vartype FROM {model_training.model_summary_table})
                     AS dependent_vartype,
-                    (SELECT normalizing_const FROM {model_training.model_summary_table}) 
+                    (SELECT normalizing_const FROM {model_training.model_summary_table})
                     AS normalizing_const
             """.format(self=self, model_training=model_training))
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
index e500970..1ebf9f6 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
@@ -21,9 +21,13 @@ import plpy
 from utilities.control import MinWarning
 from utilities.utilities import _assert
 from utilities.utilities import get_col_name_type_sql_string
+from utilities.utilities import current_user
+from utilities.utilities import is_superuser
+from utilities.utilities import get_schema
 from utilities.validate_args import columns_missing_from_table
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import quote_ident
+from utilities.validate_args import unquote_ident
 from utilities.validate_args import table_exists
 
 module_name = 'Keras Custom Function'
@@ -60,46 +64,54 @@ def _validate_object(object, **kwargs):
     except Exception as e:
         plpy.error("{0}: Invalid function object".format(module_name, e))
 
-def load_custom_function(object_table, object, name, description=None, **kwargs):
-    object_table = quote_ident(object_table)
+@MinWarning("error")
+def load_custom_function(schema_madlib, object_table, object, name, description=None, **kwargs):
+
+    if object_table is not None:
+        object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
     _validate_object(object)
     _assert(name is not None,
             "{0}: function name cannot be NULL!".format(module_name))
-    if not table_exists(object_table):
-        col_defs = get_col_name_type_sql_string(CustomFunctionSchema.col_names,
-                                                CustomFunctionSchema.col_types)
-
-        sql = "CREATE TABLE {0} ({1}, PRIMARY KEY({2}))" \
-            .format(object_table, col_defs, CustomFunctionSchema.FN_NAME)
-
-        plpy.execute(sql, 0)
-        # Using plpy.notice here as this function can be called:
-        # 1. Directly by the user, we do want to display to the user
-        #    if we create a new table or later the function name that
-        #    is added to the table
-        # 2. From load_top_k_accuracy_function, since plpy.info
-        #    displays the query context when called from the function
-        #    there is a very verbose output and cannot be suppressed with
-        #    MinWarning decorator as INFO is always displayed irrespective
-        #    of what the decorator sets the client_min_messages to.
-        #    Therefore, instead we print this information as a NOTICE
-        #    when called directly by the user and suppress it by setting
-        #    MinWarning decorator to 'error' level in the calling function.
-        plpy.notice("{0}: Created new custom function table {1}." \
-                  .format(module_name, object_table))
-    else:
-        missing_cols = columns_missing_from_table(object_table,
-                                                  CustomFunctionSchema.col_names)
-        if len(missing_cols) > 0:
-            plpy.error("{0}: Invalid custom function table {1},"
-                       " missing columns: {2}".format(module_name,
-                                                      object_table,
-                                                      missing_cols))
-
-    insert_query = plpy.prepare("INSERT INTO {0} "
-                                "VALUES(DEFAULT, $1, $2, $3);".format(object_table),
-                                CustomFunctionSchema.col_types[1:])
+    _assert(is_superuser(current_user()), "DL: The user has to have admin "\
+        "privilages to load a custom function")
     try:
+        if not table_exists(object_table):
+            col_defs = get_col_name_type_sql_string(CustomFunctionSchema.col_names,
+                                                    CustomFunctionSchema.col_types)
+
+            sql = """CREATE TABLE {object_table}
+                                  ({col_defs}, PRIMARY KEY({fn_name}))
+                """.format(fn_name=CustomFunctionSchema.FN_NAME,**locals())
+
+            plpy.execute(sql, 0)
+            # Using plpy.notice here as this function can be called:
+            # 1. Directly by the user, we do want to display to the user
+            #    if we create a new table or later the function name that
+            #    is added to the table
+            # 2. From load_top_k_accuracy_function, since plpy.info
+            #    displays the query context when called from the function
+            #    there is a very verbose output and cannot be suppressed with
+            #    MinWarning decorator as INFO is always displayed irrespective
+            #    of what the decorator sets the client_min_messages to.
+            #    Therefore, instead we print this information as a NOTICE
+            #    when called directly by the user and suppress it by setting
+            #    MinWarning decorator to 'error' level in the calling function.
+            plpy.notice("{0}: Created new custom function table {1}." \
+                      .format(module_name, object_table))
+            plpy.execute("GRANT SELECT ON {0} TO PUBLIC".format(object_table))
+        else:
+            missing_cols = columns_missing_from_table(object_table,
+                                                      CustomFunctionSchema.col_names)
+            if len(missing_cols) > 0:
+                plpy.error("{0}: Invalid custom function table {1},"
+                           " missing columns: {2}".format(module_name,
+                                                          object_table,
+                                                          missing_cols))
+
+        insert_query = plpy.prepare("INSERT INTO {object_table} "
+                                    "VALUES(DEFAULT, $1, $2, $3);".format(**locals()),
+                                    CustomFunctionSchema.col_types[1:])
+
         plpy.execute(insert_query,[name, description, object], 0)
     # spiexceptions.UniqueViolation is only supported for PG>=9.2. For
     # GP5(based of PG8.4) it cannot be used. Therefore, checking exception
@@ -112,8 +124,16 @@ def load_custom_function(object_table, object, name, description=None, **kwargs)
     plpy.notice("{0}: Added function {1} to {2} table".
               format(module_name, name, object_table))
 
-def delete_custom_function(object_table, id=None, name=None, **kwargs):
-    object_table = quote_ident(object_table)
+@MinWarning("error")
+def delete_custom_function(schema_madlib, object_table, id=None, name=None, **kwargs):
+
+    if object_table is not None:
+        schema_name = get_schema(object_table)
+        if schema_name is None:
+            object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
+        elif schema_name != schema_madlib:
+            plpy.error("DL: Custom function table has to be in the {0} schema".format(schema_madlib))
+
     input_tbl_valid(object_table, "Keras Custom Funtion")
     _assert(id is not None or name is not None,
             "{0}: function id/name cannot be NULL! " \
@@ -127,12 +147,12 @@ def delete_custom_function(object_table, id=None, name=None, **kwargs):
 
     if id is not None:
         sql = """
-               DELETE FROM {0} WHERE {1}={2}
-              """.format(object_table, CustomFunctionSchema.FN_ID, id)
+               DELETE FROM {object_table} WHERE {fn_id}={id}
+              """.format(fn_id=CustomFunctionSchema.FN_ID,**locals())
     else:
         sql = """
-               DELETE FROM {0} WHERE {1}=$${2}$$
-              """.format(object_table, CustomFunctionSchema.FN_NAME, name)
+               DELETE FROM {object_table} WHERE {fn_name}=$${name}$$
+              """.format(fn_name=CustomFunctionSchema.FN_NAME,**locals())
     res = plpy.execute(sql, 0)
 
     if res.nrows() > 0:
@@ -141,7 +161,8 @@ def delete_custom_function(object_table, id=None, name=None, **kwargs):
     else:
         plpy.error("{0}: Object id {1} not found".format(module_name, id))
 
-    sql = "SELECT {0} FROM {1}".format(CustomFunctionSchema.FN_ID, object_table)
+    sql = "SELECT {0} FROM {1}".format(
+        CustomFunctionSchema.FN_ID, object_table)
     res = plpy.execute(sql, 0)
     if not res:
         plpy.notice("{0}: Dropping empty custom keras function table " \
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
index bb9864d..43f6afc 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
@@ -374,7 +374,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_custom_function(
 RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning',`madlib_keras_custom_function')
     with AOControl(False):
-        madlib_keras_custom_function.delete_custom_function(object_table, id=id)
+        madlib_keras_custom_function.delete_custom_function(schema_madlib, object_table, id=id)
 $$ LANGUAGE plpythonu VOLATILE;
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_custom_function(
@@ -384,7 +384,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_custom_function(
 RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning',`madlib_keras_custom_function')
     with AOControl(False):
-        madlib_keras_custom_function.delete_custom_function(object_table, name=name)
+        madlib_keras_custom_function.delete_custom_function(schema_madlib, object_table, name=name)
 $$ LANGUAGE plpythonu VOLATILE;
 
 -- Functions for online help
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
index 9d9fa60..99c7150 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
@@ -33,6 +33,7 @@ from madlib_keras_wrapper import parse_and_validate_fit_params
 from madlib_keras_wrapper import parse_and_validate_compile_params
 from utilities.control import MinWarning
 from utilities.utilities import add_postfix, extract_keyvalue_params, _assert, _assert_equal
+from utilities.utilities import quote_ident, get_schema
 from utilities.validate_args import table_exists, drop_tables
 
 class ModelSelectionSchema:
@@ -65,6 +66,7 @@ class MstLoader():
     """
 
     def __init__(self,
+                 schema_madlib,
                  model_arch_table,
                  model_selection_table,
                  model_id_list,
@@ -73,13 +75,18 @@ class MstLoader():
                  object_table=None,
                  **kwargs):
 
+        self.schema_madlib = schema_madlib
         self.model_arch_table = model_arch_table
         self.model_selection_table = model_selection_table
         self.model_selection_summary_table = add_postfix(
             model_selection_table, "_summary")
         self.model_id_list = sorted(list(set(model_id_list)))
+        if object_table is not None:
+            object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
         self.object_table = object_table
+
         MstLoaderInputValidator(
+            schema_madlib=self.schema_madlib,
             model_arch_table=self.model_arch_table,
             model_selection_table=self.model_selection_table,
             model_selection_summary_table=self.model_selection_summary_table,
@@ -243,6 +250,7 @@ class MstSearch():
     """
 
     def __init__(self,
+                 schema_madlib,
                  model_arch_table,
                  model_selection_table,
                  model_id_list,
@@ -254,13 +262,21 @@ class MstSearch():
                  object_table=None,
                  **kwargs):
 
+        self.schema_madlib = schema_madlib
         self.model_arch_table = model_arch_table
         self.model_selection_table = model_selection_table
         self.model_selection_summary_table = add_postfix(
             model_selection_table, "_summary")
         self.model_id_list = sorted(list(set(model_id_list)))
 
+        if object_table is not None:
+            schema_name = get_schema(object_table)
+            if schema_name is None:
+                object_table = "{0}.{1}".format(schema_madlib, quote_ident(object_table))
+            elif schema_name != schema_madlib:
+                plpy.error("DL: Custom function table has to be in the {0} schema".format(schema_madlib))
         MstLoaderInputValidator(
+            schema_madlib=self.schema_madlib,
             model_arch_table=self.model_arch_table,
             model_selection_table=self.model_selection_table,
             model_selection_summary_table=self.model_selection_summary_table,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index fac7357..9382407 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -459,6 +459,7 @@ class FitMultipleInputValidator(FitCommonValidator):
 
 class MstLoaderInputValidator():
     def __init__(self,
+                 schema_madlib,
                  model_arch_table,
                  model_selection_table,
                  model_selection_summary_table,
@@ -468,6 +469,7 @@ class MstLoaderInputValidator():
                  object_table,
                  module_name='load_model_selection_table'
                  ):
+        self.schema_madlib = schema_madlib
         self.model_arch_table = model_arch_table
         self.model_selection_table = model_selection_table
         self.model_selection_summary_table = model_selection_summary_table
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
index 520b9c9..d9a323a 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
@@ -158,3 +158,29 @@ FROM __test_custom_function_table__ WHERE id = 3;
 
 SELECT assert(name = 'top_8_accuracy', 'Top 8 accuracy name is incorrect')
 FROM __test_custom_function_table__ WHERE id = 4;
+
+CREATE SCHEMA MADLIB_SCHEMA_aaa;
+CREATE TABLE pg_temp.temp1 AS SELECT * FROM MADLIB_SCHEMA.test_custom_function_table;
+CREATE TABLE pg_temp.MADLIB_SCHEMA AS SELECT * FROM MADLIB_SCHEMA.test_custom_function_table;
+CREATE TABLE MADLIB_SCHEMA_aaa.test_table AS SELECT * FROM MADLIB_SCHEMA.test_custom_function_table;
+
+SELECT assert(MADLIB_SCHEMA.trap_error($$
+  SELECT load_custom_function('pg_temp.temp1', custom_function_object(), 'sum_fn', 'returns sum');
+$$) = 1, 'Cannot use non-madlib schemas');
+
+SELECT assert(MADLIB_SCHEMA.trap_error($$
+  SELECT load_custom_function('test_custom_function_table UNION pg_temp.temp1',
+    custom_function_object(), 'sum_fn', 'returns sum');
+$$) = 1, 'UNION should not pass');
+
+SELECT assert(MADLIB_SCHEMA.trap_error($$
+  SELECT load_custom_function('pg_temp.MADLIB_SCHEMA', custom_function_object(), 'sum_fn', 'returns sum');
+$$) = 1, 'pg_temp.MADLIB_SCHEMA should not pass');
+
+SELECT assert(MADLIB_SCHEMA.trap_error($$
+  SELECT load_custom_function('MADLIB_SCHEMA_aaa.test_table', custom_function_object(), 'sum_fn', 'returns sum');
+$$) = 1, 'test_schema.MADLIB_SCHEMA should not pass');
+
+DROP SCHEMA MADLIB_SCHEMA_aaa CASCADE;
+DROP TABLE IF EXISTS pg_temp.temp1;
+DROP TABLE IF EXISTS pg_temp.MADLIB_SCHEMA;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
index b002550..d6bcae7 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -145,7 +145,6 @@ FROM evaluate_out;
 
 -- TEST custom loss function
 
-
 DROP TABLE IF EXISTS test_custom_function_table;
 SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
 SELECT load_custom_function('test_custom_function_table', custom_function_one_object(), 'test_custom_fn1', 'returns test_custom_fn1');
@@ -202,7 +201,6 @@ SELECT assert(
         pg_typeof(normalizing_const) = 'real'::regtype AND
         name is NULL AND
         description is NULL AND
-        object_table = 'test_custom_function_table' AND
         model_size > 0 AND
         madlib_version is NOT NULL AND
         compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['top_3_accuracy']$$::text AND
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 0c29246..49b6940 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -424,11 +424,6 @@ SELECT load_model_selection_table(
     'test_custom_function_table'
 );
 
-SELECT assert(
-        object_table = 'test_custom_function_table',
-        'Keras Fit Multiple Output Summary Validation failed when user passes in object_table. Actual:' || __to_char(summary))
-FROM (SELECT * FROM mst_object_table_summary) summary;
-
 -- Test when number of configs(3) equals number of segments(3)
 CREATE OR REPLACE FUNCTION test_fit_multiple_equal_configs(caching boolean)
 RETURNS VOID AS
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
index dbd29ed..1a2f61f 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
@@ -51,12 +51,13 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
         self.module.MstLoaderInputValidator = MagicMock()
 
         self.subject = self.module.MstSearch
+        self.madlib_schema = 'mad'
         self.model_selection_table = 'mst_table'
         self.model_arch_table = 'model_arch_library'
         self.model_id_list = [1, 2]
         self.compile_params_grid = """
-            {'loss': ['categorical_crossentropy'], 
-            'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1]} ], 
+            {'loss': ['categorical_crossentropy'],
+            'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1]} ],
             'metrics': ['accuracy']}
         """
         self.fit_params_grid = """
@@ -69,6 +70,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
 
     def test_mst_table_dimension(self):
         generate_mst = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -78,6 +80,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
         self.assertEqual(32, len(generate_mst.msts))
 
         generate_mst = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -90,6 +93,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
         self.assertEqual(9, len(generate_mst.msts))
 
         generate_mst = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -102,11 +106,12 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
         self.assertEqual(9, len(generate_mst.msts))
 
         self.compile_params_grid = """
-            {'loss': ['categorical_crossentropy'],   
+            {'loss': ['categorical_crossentropy'],
             'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1]} ],
             'metrics': ['accuracy']}
         """
         generate_mst = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -121,6 +126,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
     def test_invalid_input_args(self):
         with self.assertRaises(plpy.PLPYException):
             generate_mst = self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_id_list,
@@ -131,6 +137,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
             )
         with self.assertRaises(plpy.PLPYException):
             generate_mst = self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_id_list,
@@ -143,6 +150,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
             )
         with self.assertRaises(plpy.PLPYException):
             generate_mst = self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_id_list,
@@ -152,6 +160,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
             )
         with self.assertRaises(plpy.PLPYException):
             generate_mst = self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_id_list,
@@ -163,6 +172,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
             )
         with self.assertRaises(plpy.PLPYException):
             generate_mst = self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 [-3],
@@ -184,6 +194,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
 
         with self.assertRaises(plpy.PLPYException):
             generate_mst = self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_id_list,
@@ -194,6 +205,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
     def test_duplicate_params(self):
         self.model_id_list = [1, 1, 2]
         generate_mst = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -202,6 +214,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
         )
         self.assertEqual(32, len(generate_mst.msts))
         generate_mst = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -222,6 +235,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
                                    "'metrics': ['accuracy']}"
 
         generate_mst1 = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -232,6 +246,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
 
     def test_output_types(self):
         generate_mst1 = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -243,6 +258,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
             self.assertEqual("loss='categorical_crossentropy'" in d1['compile_params'], True)
 
         generate_mst2 = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -257,6 +273,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
 
     def test_seed_result_reproducibility(self):
         generate_mst1 = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -267,6 +284,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
             47
         )
         generate_mst2 = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -277,6 +295,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
             47
         )
         generate_mst3 = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -296,6 +315,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
             'metrics': ['accuracy']}
         """
         generate_mst1 = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -313,6 +333,7 @@ class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
             'metrics': ['accuracy']}
         """
         generate_mst2 = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -347,6 +368,7 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
             MagicMock()
 
         self.subject = self.module.MstLoader
+        self.madlib_schema = 'mad'
         self.model_selection_table = 'mst_table'
         self.model_arch_table = 'model_arch_library'
         self.object_table = 'custom_function_table'
@@ -375,6 +397,7 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
 
     def test_mst_table_dimension(self):
         generate_mst = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -389,6 +412,7 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
             .side_effect = plpy.PLPYException('Invalid input args')
         with self.assertRaises(plpy.PLPYException):
             generate_mst = self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_id_list,
@@ -402,6 +426,7 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
             .side_effect = plpy.PLPYException('Invalid input args')
         with self.assertRaises(plpy.PLPYException):
             generate_mst = self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_id_list,
@@ -435,6 +460,7 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
             "batch_size=10,epochs =1"
         ]
         generate_mst = self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_id_list,
@@ -464,6 +490,7 @@ class MstLoaderInputValidatorTestCase(unittest.TestCase):
         self.module = deep_learning.madlib_keras_validator
 
         self.subject = self.module.MstLoaderInputValidator
+        self.madlib_schema = 'mad'
         self.model_selection_table = 'mst_table'
         self.model_arch_table = 'model_arch_library'
         self.model_arch_summary_table = 'model_arch_library_summary'
@@ -497,6 +524,7 @@ class MstLoaderInputValidatorTestCase(unittest.TestCase):
         self.subject.parse_and_validate_fit_params = Mock()
 
         self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_arch_summary_table,
@@ -513,6 +541,7 @@ class MstLoaderInputValidatorTestCase(unittest.TestCase):
         self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
                                                {'name': 'custom_fn2'}]]
         self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_arch_summary_table,
@@ -536,6 +565,7 @@ class MstLoaderInputValidatorTestCase(unittest.TestCase):
             """
         ]
         self.subject(
+            self.madlib_schema,
             self.model_selection_table,
             self.model_arch_table,
             self.model_arch_summary_table,
@@ -561,6 +591,7 @@ class MstLoaderInputValidatorTestCase(unittest.TestCase):
 
         with self.assertRaises(plpy.PLPYException) as error:
             self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_arch_summary_table,
@@ -586,6 +617,7 @@ class MstLoaderInputValidatorTestCase(unittest.TestCase):
         ]
         with self.assertRaises(plpy.PLPYException) as error:
             self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_arch_summary_table,
@@ -612,6 +644,7 @@ class MstLoaderInputValidatorTestCase(unittest.TestCase):
         ]
         with self.assertRaises(plpy.PLPYException) as error:
             self.subject(
+                self.madlib_schema,
                 self.model_selection_table,
                 self.model_arch_table,
                 self.model_arch_summary_table,
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 946b139..b228d68 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -15,6 +15,7 @@ from validate_args import input_tbl_valid
 from validate_args import is_var_valid
 from validate_args import output_tbl_valid
 from validate_args import quote_ident
+from validate_args import unquote_ident
 from validate_args import drop_tables
 import plpy
 
@@ -788,6 +789,10 @@ def current_user():
     return plpy.execute("SELECT current_user")[0]['current_user']
 # ------------------------------------------------------------------------
 
+def is_superuser(user):
+    """Return True if the role named `user` exists and has rolsuper set, else False."""
+    plan = plpy.prepare("SELECT rolsuper FROM pg_catalog.pg_roles WHERE rolname = $1", ["text"])
+    return any(row['rolsuper'] for row in plpy.execute(plan, [user]))
 
 def madlib_version(schema_madlib):
     """Returns the MADlib version string."""
@@ -1298,3 +1303,17 @@ def get_psql_type(py_type):
         return 'varchar'
     else:
         plpy.error("Cannot determine the type of {0}".format(py_type))
+
+
+def get_schema(tbl_str):
+    """Return the schema part of '[schema.]table' (via unquote_ident), or None if absent."""
+    names = tbl_str.split('.')
+    # NOTE(review): naive split; quoted identifiers containing '.' are not supported.
+    if len(names) > 2:
+        raise TypeError("Incorrect table name ({0}) provided! Table name should be "
+                        "of the form: <schema name>.<table name>".format(tbl_str))
+    if len(names) == 2:
+        return unquote_ident(names[0])
+    # Unqualified name: there is no schema component to report.
+    # str.split always yields at least one element, so no empty-list case exists.
+    return None