You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2020/08/28 01:14:16 UTC

[GitHub] [madlib] khannaekta commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

khannaekta commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r478773154



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -17,17 +17,35 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from datetime import datetime
 import plpy
 import math
+from time import time
 
-from utilities.utilities import _assert
+from madlib_keras_validator import MstLoaderInputValidator
+from utilities.utilities import add_postfix, extract_keyvalue_params, _assert, _assert_equal
 from utilities.control import MinWarning
+from madlib_keras_fit_multiple_model import FitMultipleModel
+from madlib_keras_model_selection import MstSearch, ModelSelectionSchema
+from keras_model_arch_table import ModelArchSchema
+from utilities.validate_args import table_exists, drop_tables
+
 
 class AutoMLSchema:
     BRACKET = 's'
     ROUND = 'i'
     CONFIGURATIONS = 'n_i'
     RESOURCES = 'r_i'
+    HYPERBAND = 'hyperband'
+    R = 'R'
+    ETA = 'eta'
+    SKIP_LAST = 'skip_last'
+    LOSS_METRIC = 'training_loss_final'
+    TEMP_MST_TABLE = 'temp_mst_table'
+    TEMP_MST_SUMMARY_TABLE = add_postfix(TEMP_MST_TABLE, '_summary')
+    TEMP_OUTPUT_TABLE = 'temp_output_table'

Review comment:
       We should use `unique_string('temp_output_table')` here to generate a unique table name.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 or len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max))
+            # total number of resources/iterations (without reuse) per execution of Succesive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: interations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,
+                                              AutoMLSchema.TEMP_MST_TABLE, num_iterations, self.use_gpus,
+                                              self.validation_table, mcf, self.warm_start, self.name, self.description)
+            self.update_model_output_table(model_training)
+            self.update_model_output_info_table(i, model_training,initial_vals)
+        self.end_training_time = self.get_current_timestamp()
+        self.update_model_selection_table()
+        self.generate_model_output_summary_table(model_training)
+        self.remove_temp_tables(model_training)
+
+    def get_current_timestamp(self):
+        """for start and end times for the chosen AutoML algorithm. Showcased in the output summary table"""
+        return datetime.fromtimestamp(time()).strftime('%Y-%m-%d %H:%M:%S')
+
+    def mst_key_ranges_dict(self, initial_vals):
+        """
+        Extracts the ranges of model configs (using mst_keys) belonging to / sampled as part of
+        executing a particular SHA bracket.
+        """
+        d = {}
+        for s_val in sorted(initial_vals.keys(), reverse=True): # going from s_max to 0
+            if s_val == self.s_max:
+                d[s_val] = (1, initial_vals[s_val][0])
+            else:
+                d[s_val] = (d[s_val+1][1]+1, d[s_val+1][1]+initial_vals[s_val][0])
+        return d
+
+    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
+        """
+        Drops and Reconstructs a temp mst table for evaluation along particular diagonals of hyperband.
+        :param i: outer diagonal loop iteration.
+        :param ranges_dict: model config ranges to group by bracket number.
+        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
+        :return:
+        """
+        if i == 0:
+            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
+            lower_bound, upper_bound = ranges_dict[self.s_max]
+            plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
+                         "WHERE mst_key >= {lower_bound} AND mst_key <= {upper_bound}".format(self=self,
+                                                                                              AutoMLSchema=AutoMLSchema,
+                                                                                              lower_bound=lower_bound,
+                                                                                              upper_bound=upper_bound,))
+            return
+        # dropping and repopulating temp_mst_table
+        drop_tables([AutoMLSchema.TEMP_MST_TABLE])
+
+        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
+        create_query = """
+                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
+                            {mst_key} INTEGER,
+                            {model_id} INTEGER,
+                            {compile_params} VARCHAR,
+                            {fit_params} VARCHAR,
+                            unique ({model_id}, {compile_params}, {fit_params})
+                        );
+                       """.format(AutoMLSchema=AutoMLSchema,
+                                  mst_key=ModelSelectionSchema.MST_KEY,
+                                  model_id=ModelSelectionSchema.MODEL_ID,
+                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
+        with MinWarning('warning'):
+            plpy.execute(create_query)
+
+        query = ""
+        new_configs = True
+        for s_val in configs_prune_lookup:
+            lower_bound, upper_bound = ranges_dict[s_val]
+            if new_configs:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_selection_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                lower_bound=lower_bound, upper_bound=upper_bound)
+                new_configs = False
+            else:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_info_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
+                         "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                    lower_bound=lower_bound, upper_bound=upper_bound,
+                                                                    configs_prune_lookup_val=configs_prune_lookup[s_val])
+        plpy.execute(query)
+
+    def update_model_output_table(self, model_training):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output table.
+        :param model_training: Fit Multiple function call object.
+        """
+        # updates model weights for any previously trained configs
+        plpy.execute("UPDATE {self.model_output_table} a SET model_weights="
+                     "t.model_weights FROM {model_training.original_model_output_table} t " \
+                     "WHERE a.mst_key=t.mst_key".format(self=self, model_training=model_training))
+
+        # inserts any newly trained configs
+        plpy.execute("INSERT INTO {self.model_output_table} SELECT * FROM {model_training.original_model_output_table} " \
+                     "WHERE {model_training.original_model_output_table}.mst_key NOT IN "
+                     "(SELECT mst_key FROM {self.model_output_table})".format(self=self,
+                                                                              model_training=model_training))
+
+    def update_model_output_info_table(self, i, model_training, initial_vals):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output info table.
+        :param i: outer diagonal loop iteration.
+        :param model_training: Fit Multiple function call object.
+        :param initial_vals: Dictionary of initial configurations and resources as part of the initial hyperband
+        schedule.
+        """
+        # normalizing factor for metrics_iters due to warm start
+        epochs_factor = sum([n[1] for n in initial_vals.values()][::-1][:i])
+        iters = plpy.execute("SELECT {AutoMLSchema.METRICS_ITERS} " \
+                             "FROM {model_training.model_summary_table}".format(AutoMLSchema=AutoMLSchema,
+                                                                                model_training=model_training))
+        metrics_iters_val = [epochs_factor+mi for mi in iters[0]['metrics_iters']]
+
+        # casting same metrics_iters values for the fit_multiple run with the chosen configs
+        # in order to update overall info table later
+        plpy.execute("ALTER TABLE {model_training.model_info_table} " \
+                     "ADD COLUMN {AutoMLSchema.METRICS_ITERS} INTEGER[]".format(model_training=model_training,
+                                                                                AutoMLSchema=AutoMLSchema))
+        plpy.execute("UPDATE {model_training.model_info_table} SET {AutoMLSchema.METRICS_ITERS} = " \
+                     "ARRAY{metrics_iters_val}::INTEGER[] FROM {model_training.model_summary_table}".format(
+            model_training=model_training, metrics_iters_val=metrics_iters_val, AutoMLSchema=AutoMLSchema))
+        validation_update_q = "validation_metrics_final=t.validation_metrics_final, " \
+                                     "validation_loss_final=t.validation_loss_final, " \
+                                     "validation_metrics=a.validation_metrics || t.validation_metrics, " \
+                                     "validation_loss=a.validation_loss || t.validation_loss, " \
+            if self.validation_table else ""
+
+        # updates train/val info for any previously trained configs
+        plpy.execute("UPDATE {self.model_info_table} a SET " \
+                     "metrics_elapsed_time=a.metrics_elapsed_time || t.metrics_elapsed_time, " \
+                     "training_metrics_final=t.training_metrics_final, " \
+                     "training_loss_final=t.training_loss_final, " \
+                     "training_metrics= a.training_metrics || t.training_metrics, " \
+                     "training_loss= a.training_loss || t.training_loss, ".format(self=self) + validation_update_q +
+                     "metrics_iters=a.metrics_iters || t.metrics_iters " \
+                     "FROM {model_training.model_info_table} t " \
+                     "WHERE a.mst_key=t.mst_key".format(model_training=model_training))
+
+        # inserts info about metrics and validation for newly trained model configs
+        plpy.execute("INSERT INTO {self.model_info_table} SELECT * FROM {model_training.model_info_table} " \
+                     "WHERE {model_training.model_info_table}.mst_key NOT IN "
+                     "(SELECT mst_key FROM {self.model_info_table})".format(self=self,
+                                                                            model_training=model_training))
+
+        # removes metrics_iters column from temp mst info table, for further use in warm started model training
+        plpy.execute("ALTER TABLE {model_training.model_info_table} " \
+                     "DROP COLUMN {AutoMLSchema.METRICS_ITERS}".format(model_training=model_training,
+                                                                       AutoMLSchema=AutoMLSchema))
+
+    def update_model_selection_table(self):
+        """
+        Drops and re-createst the mst table to only include the best performing model configuration.

Review comment:
       Typo: `re-createst` should be `re-creates`.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -17,17 +17,35 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from datetime import datetime
 import plpy
 import math
+from time import time
 
-from utilities.utilities import _assert
+from madlib_keras_validator import MstLoaderInputValidator
+from utilities.utilities import add_postfix, extract_keyvalue_params, _assert, _assert_equal
 from utilities.control import MinWarning
+from madlib_keras_fit_multiple_model import FitMultipleModel
+from madlib_keras_model_selection import MstSearch, ModelSelectionSchema
+from keras_model_arch_table import ModelArchSchema
+from utilities.validate_args import table_exists, drop_tables
+
 
 class AutoMLSchema:
     BRACKET = 's'
     ROUND = 'i'
     CONFIGURATIONS = 'n_i'
     RESOURCES = 'r_i'
+    HYPERBAND = 'hyperband'
+    R = 'R'
+    ETA = 'eta'
+    SKIP_LAST = 'skip_last'
+    LOSS_METRIC = 'training_loss_final'
+    TEMP_MST_TABLE = 'temp_mst_table'

Review comment:
       We should use `unique_string('temp_mst_table')` instead to create a unique table name.

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
##########
@@ -21,6 +21,347 @@
 
 m4_include(`SQLCommon.m4')
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+
+--------------------------- MADLIB KERAS AUTOML HYPERBAND TEST CASES ---------------------------
+
+DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
+SELECT training_preprocessor_dl('iris_train',         -- Source table
+                                       'iris_train_packed',  -- Output table
+                                       'class_text',         -- Dependent variable
+                                       'attributes'          -- Independent variable
+           );
+
+DROP TABLE IF EXISTS iris_test_packed, iris_test_packed_summary;
+SELECT validation_preprocessor_dl('iris_test',          -- Source table
+                                         'iris_test_packed',   -- Output table

Review comment:
       Indentation is off: the continuation arguments should line up with the first argument.

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
##########
@@ -21,6 +21,347 @@
 
 m4_include(`SQLCommon.m4')
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+
+--------------------------- MADLIB KERAS AUTOML HYPERBAND TEST CASES ---------------------------
+
+DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
+SELECT training_preprocessor_dl('iris_train',         -- Source table
+                                       'iris_train_packed',  -- Output table

Review comment:
       Indentation is off.

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
##########
@@ -21,6 +21,347 @@
 
 m4_include(`SQLCommon.m4')
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+
+--------------------------- MADLIB KERAS AUTOML HYPERBAND TEST CASES ---------------------------
+
+DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
+SELECT training_preprocessor_dl('iris_train',         -- Source table
+                                       'iris_train_packed',  -- Output table

Review comment:
       You can remove these calls to `training_preprocessor_dl`/`validation_preprocessor_dl`, since they are already made in the setup file `madlib_keras_iris.setup.sql_in`; use `iris_data_packed` as the input table instead.
   

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 or len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max))
+            # total number of resources/iterations (without reuse) per execution of Succesive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: interations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,
+                                              AutoMLSchema.TEMP_MST_TABLE, num_iterations, self.use_gpus,
+                                              self.validation_table, mcf, self.warm_start, self.name, self.description)
+            self.update_model_output_table(model_training)
+            self.update_model_output_info_table(i, model_training,initial_vals)
+        self.end_training_time = self.get_current_timestamp()
+        self.update_model_selection_table()
+        self.generate_model_output_summary_table(model_training)
+        self.remove_temp_tables(model_training)
+
+    def get_current_timestamp(self):
+        """for start and end times for the chosen AutoML algorithm. Showcased in the output summary table"""
+        return datetime.fromtimestamp(time()).strftime('%Y-%m-%d %H:%M:%S')
+
+    def mst_key_ranges_dict(self, initial_vals):
+        """
+        Extracts the ranges of model configs (using mst_keys) belonging to / sampled as part of
+        executing a particular SHA bracket.
+        """
+        d = {}
+        for s_val in sorted(initial_vals.keys(), reverse=True): # going from s_max to 0
+            if s_val == self.s_max:
+                d[s_val] = (1, initial_vals[s_val][0])
+            else:
+                d[s_val] = (d[s_val+1][1]+1, d[s_val+1][1]+initial_vals[s_val][0])
+        return d
+
+    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
+        """
+        Drops and Reconstructs a temp mst table for evaluation along particular diagonals of hyperband.
+        :param i: outer diagonal loop iteration.
+        :param ranges_dict: model config ranges to group by bracket number.
+        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
+        :return:
+        """
+        if i == 0:
+            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
+            lower_bound, upper_bound = ranges_dict[self.s_max]
+            plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
+                         "WHERE mst_key >= {lower_bound} AND mst_key <= {upper_bound}".format(self=self,
+                                                                                              AutoMLSchema=AutoMLSchema,
+                                                                                              lower_bound=lower_bound,
+                                                                                              upper_bound=upper_bound,))
+            return
+        # dropping and repopulating temp_mst_table
+        drop_tables([AutoMLSchema.TEMP_MST_TABLE])
+
+        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
+        create_query = """
+                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
+                            {mst_key} INTEGER,
+                            {model_id} INTEGER,
+                            {compile_params} VARCHAR,
+                            {fit_params} VARCHAR,
+                            unique ({model_id}, {compile_params}, {fit_params})
+                        );
+                       """.format(AutoMLSchema=AutoMLSchema,
+                                  mst_key=ModelSelectionSchema.MST_KEY,
+                                  model_id=ModelSelectionSchema.MODEL_ID,
+                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
+        with MinWarning('warning'):
+            plpy.execute(create_query)
+
+        query = ""
+        new_configs = True
+        for s_val in configs_prune_lookup:
+            lower_bound, upper_bound = ranges_dict[s_val]
+            if new_configs:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_selection_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                lower_bound=lower_bound, upper_bound=upper_bound)
+                new_configs = False
+            else:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_info_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
+                         "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                    lower_bound=lower_bound, upper_bound=upper_bound,
+                                                                    configs_prune_lookup_val=configs_prune_lookup[s_val])
+        plpy.execute(query)
+
+    def update_model_output_table(self, model_training):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output table.
+        :param model_training: Fit Multiple function call object.
+        """
+        # updates model weights for any previously trained configs
+        plpy.execute("UPDATE {self.model_output_table} a SET model_weights="

Review comment:
       UPDATE is essentially a DELETE followed by an INSERT. In the past we saw that updating the table just doubles its size, because the update runs inside a UDF (a single transaction). If this table is going to hold at least as many mst_keys as there are segments, and update that many keys on every diagonal it trains, I think we should TRUNCATE the table and INSERT into it instead of calling UPDATE.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 or len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max))
+            # total number of resources/iterations (without reuse) per execution of Succesive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: interations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,

Review comment:
       `FitMultipleModel` runs additional validation logic (on mst_table, output_table, etc.) for every diagonal that we explore. Can we initialize the object (model_training) once and just call `model_training.run_training()` for each diagonal?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 or len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max))
+            # total number of resources/iterations (without reuse) per execution of Succesive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: interations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,
+                                              AutoMLSchema.TEMP_MST_TABLE, num_iterations, self.use_gpus,
+                                              self.validation_table, mcf, self.warm_start, self.name, self.description)
+            self.update_model_output_table(model_training)
+            self.update_model_output_info_table(i, model_training,initial_vals)
+        self.end_training_time = self.get_current_timestamp()
+        self.update_model_selection_table()
+        self.generate_model_output_summary_table(model_training)
+        self.remove_temp_tables(model_training)
+
+    def get_current_timestamp(self):
+        """for start and end times for the chosen AutoML algorithm. Showcased in the output summary table"""
+        return datetime.fromtimestamp(time()).strftime('%Y-%m-%d %H:%M:%S')
+
+    def mst_key_ranges_dict(self, initial_vals):
+        """
+        Extracts the ranges of model configs (using mst_keys) belonging to / sampled as part of
+        executing a particular SHA bracket.
+        """
+        d = {}
+        for s_val in sorted(initial_vals.keys(), reverse=True): # going from s_max to 0
+            if s_val == self.s_max:
+                d[s_val] = (1, initial_vals[s_val][0])
+            else:
+                d[s_val] = (d[s_val+1][1]+1, d[s_val+1][1]+initial_vals[s_val][0])
+        return d
+
+    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
+        """
+        Drops and Reconstructs a temp mst table for evaluation along particular diagonals of hyperband.
+        :param i: outer diagonal loop iteration.
+        :param ranges_dict: model config ranges to group by bracket number.
+        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
+        :return:
+        """
+        if i == 0:
+            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
+            lower_bound, upper_bound = ranges_dict[self.s_max]
+            plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
+                         "WHERE mst_key >= {lower_bound} AND mst_key <= {upper_bound}".format(self=self,
+                                                                                              AutoMLSchema=AutoMLSchema,
+                                                                                              lower_bound=lower_bound,
+                                                                                              upper_bound=upper_bound,))
+            return
+        # dropping and repopulating temp_mst_table
+        drop_tables([AutoMLSchema.TEMP_MST_TABLE])
+
+        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
+        create_query = """
+                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
+                            {mst_key} INTEGER,
+                            {model_id} INTEGER,
+                            {compile_params} VARCHAR,
+                            {fit_params} VARCHAR,
+                            unique ({model_id}, {compile_params}, {fit_params})
+                        );
+                       """.format(AutoMLSchema=AutoMLSchema,
+                                  mst_key=ModelSelectionSchema.MST_KEY,
+                                  model_id=ModelSelectionSchema.MODEL_ID,
+                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
+        with MinWarning('warning'):
+            plpy.execute(create_query)
+
+        query = ""
+        new_configs = True
+        for s_val in configs_prune_lookup:
+            lower_bound, upper_bound = ranges_dict[s_val]
+            if new_configs:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_selection_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                lower_bound=lower_bound, upper_bound=upper_bound)
+                new_configs = False
+            else:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_info_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
+                         "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                    lower_bound=lower_bound, upper_bound=upper_bound,
+                                                                    configs_prune_lookup_val=configs_prune_lookup[s_val])
+        plpy.execute(query)
+
+    def update_model_output_table(self, model_training):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output table.
+        :param model_training: Fit Multiple function call object.
+        """
+        # updates model weights for any previously trained configs
+        plpy.execute("UPDATE {self.model_output_table} a SET model_weights="
+                     "t.model_weights FROM {model_training.original_model_output_table} t " \
+                     "WHERE a.mst_key=t.mst_key".format(self=self, model_training=model_training))
+
+        # inserts any newly trained configs
+        plpy.execute("INSERT INTO {self.model_output_table} SELECT * FROM {model_training.original_model_output_table} " \
+                     "WHERE {model_training.original_model_output_table}.mst_key NOT IN "
+                     "(SELECT mst_key FROM {self.model_output_table})".format(self=self,
+                                                                              model_training=model_training))
+
+    def update_model_output_info_table(self, i, model_training, initial_vals):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output info table.
+        :param i: outer diagonal loop iteration.
+        :param model_training: Fit Multiple function call object.
+        :param initial_vals: Dictionary of initial configurations and resources as part of the initial hyperband
+        schedule.
+        """
+        # normalizing factor for metrics_iters due to warm start
+        epochs_factor = sum([n[1] for n in initial_vals.values()][::-1][:i])
+        iters = plpy.execute("SELECT {AutoMLSchema.METRICS_ITERS} " \
+                             "FROM {model_training.model_summary_table}".format(AutoMLSchema=AutoMLSchema,
+                                                                                model_training=model_training))
+        metrics_iters_val = [epochs_factor+mi for mi in iters[0]['metrics_iters']]
+
+        # casting same metrics_iters values for the fit_multiple run with the chosen configs
+        # in order to update overall info table later
+        plpy.execute("ALTER TABLE {model_training.model_info_table} " \

Review comment:
       Why do we ALTER TABLE to add/drop a column on every diagonal we explore? ALTER TABLE rewrites the entire table.
   We should avoid it, since we only need to add the metrics_iters values into `self.model_info_table`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org