Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2020/08/22 01:19:45 UTC

[GitHub] [madlib] Advitya17 opened a new pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Advitya17 opened a new pull request #513:
URL: https://github.com/apache/madlib/pull/513


   JIRA: MADLIB-{1447,1448,1449}
   
   We integrate AutoML capabilities into Apache MADlib by introducing a function called `madlib_keras_automl`, which brings setting up and running model selection together and helps automate and accelerate model selection and training end to end. The user declaratively specifies the names of the train/validation datasets, the model selection (mst) and output tables, the model architecture and parameter grid details, the chosen AutoML method and its parameters, and various training options. The API then handles scheduling and execution and displays the algorithm's workload information to the user.
   
   The first AutoML algorithm we implement is Hyperband, a state-of-the-art hyperparameter optimization algorithm that speeds up random search through adaptive resource allocation, successive halving (SHA), and early stopping. It generates a schedule from the user's inputs and evaluates model configurations more efficiently by devoting more resources to the more promising configurations and stopping the rest early.
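   As background for readers new to SHA, here is a generic successive-halving loop in Python. It is not code from this PR: `evaluate` and `toy_loss` are hypothetical stand-ins for training a configuration for a given number of iterations and reading back its loss. With n=9, r=1, eta=3 it runs 9 configs for 1 iteration, the best 3 for 3 iterations, and the best 1 for 9 iterations, which is one Hyperband bracket.

```python
import math


def successive_halving(configs, r, eta, evaluate):
    """Generic successive halving: evaluate every surviving config with the
    current resources, keep the best len // eta, multiply resources by eta.

    evaluate(config, resources) -> loss (lower is better).
    Returns the scored (config, loss) pairs of the last round that was run.
    """
    survivors = list(configs)
    resources = r
    while True:
        scored = sorted(((c, evaluate(c, resources)) for c in survivors),
                        key=lambda pair: pair[1])
        keep = len(survivors) // eta
        if keep < 1:  # too few configs left to halve again
            return scored
        survivors = [c for c, _ in scored[:keep]]
        resources *= eta


def toy_loss(lr, resources):
    # Hypothetical loss: lowest near lr = 1e-3, and shrinks with more training.
    return (math.log10(lr) + 3) ** 2 + 1.0 / resources


# 9 learning rates from 1 down to 1e-4; one bracket with n=9, r=1, eta=3
lrs = [10 ** (-k / 2.0) for k in range(9)]
print(successive_halving(lrs, r=1, eta=3, evaluate=toy_loss))
```

   Hyperband then runs several such brackets with different (n, r) trade-offs, so that at least one of them avoids discarding slow starters too aggressively.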
   
   In the case of MPP databases such as Greenplum, we further accelerate this algorithm by simultaneously evaluating multiple rounds that lie along a 'diagonal' of the schedule (rounds from different brackets that are allotted the same number of iterations). This keeps all machines busy and takes advantage of the large distributed storage and compute power offered by Greenplum.
   
   The diagonal approach also lets us add some low-level optimizations that improve runtime and code quality:
   
   1. Reducing the number of random search function calls from `s_max+1` to just `1`.
   2. Reducing the number of multiple-model training (fit multiple) function calls from `s_max(s_max+1)/2` to `s_max+1`.
   3. Reducing the number of sampled SHA configuration groups from `s_max+1` to `s_max+1-skip_last` (i.e., sampling only the configurations actually needed for evaluation).
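   To make the diagonal grouping concrete, the Python sketch below mirrors the schedule arithmetic of `find_hyperband_config()` in the diff further down (R, eta, s_max, and skip_last as defined in the Key that follows). It is an illustration, not the PR's code: the helper names are made up, and `s_max` is computed with integer arithmetic instead of `math.log` to sidestep floating-point rounding. Each entry of the returned plan corresponds to one training call, so there are `s_max+1-skip_last` of them.

```python
import math


def s_max_for(R, eta):
    """Integer equivalent of floor(log(R) / log(eta)) for R >= 1, eta >= 2."""
    s = 0
    while eta ** (s + 1) <= R:
        s += 1
    return s


def hyperband_brackets(R, eta):
    """Initial (n, r) per bracket s, following the formulas in the diff."""
    s_max = s_max_for(R, eta)
    B = (s_max + 1) * R  # total resources per SHA bracket
    brackets = {}
    for s in range(s_max, -1, -1):
        # As in the diff, int() truncates B/R/(s+1) before scaling by eta**s.
        n = int(math.ceil(int(B / R / (s + 1)) * eta ** s))
        r = int(round(R * float(eta) ** -s))
        brackets[s] = (n, r)
    return s_max, brackets


def diagonal_plan(R, eta, skip_last=0):
    """One entry per diagonal (i.e. per training call):
    (iterations, [(bracket s, round, number of configs), ...])."""
    s_max, brackets = hyperband_brackets(R, eta)
    plan = []
    for i in range(s_max + 1 - skip_last):
        groups = []
        for s in range(s_max, s_max - i - 1, -1):
            n_i = int(round(brackets[s][0] * float(eta) ** (s_max - s - i)))
            groups.append((s, s - s_max + i, n_i))
        # every (bracket, round) pair on diagonal i gets the same iterations
        plan.append((brackets[s_max - i][1], groups))
    return plan


def configs_to_sample(R, eta, skip_last=0):
    """Only brackets s >= skip_last are ever reached by the diagonals."""
    _, brackets = hyperband_brackets(R, eta)
    return sum(n for s, (n, _) in brackets.items() if s >= skip_last)


for iters, groups in diagonal_plan(9, 3):
    print(iters, groups)
```

   For R=9 and eta=3 (s_max=2), the three diagonals train 9 configs for 1 iteration, then 3+3 configs for 3 iterations, then 1+1+3 configs for 9 iterations. With skip_last=2 only the first diagonal runs, so only the 9 configs of bracket s=2 need to be sampled.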
   
   Key:
   R --> maximum amount of resources/iterations that can be allocated to a single configuration in any single round of Hyperband
   eta --> factor controlling the proportion of configs discarded in each round of SHA (each round keeps roughly 1/eta of the configs)
   s_max = floor(log(R)/log(eta)) --> controls the number of SHA brackets (s_max+1) executed by Hyperband
   skip_last --> number of diagonals to skip at the end (to avoid running the most time- and resource-intensive bracket(s) and/or to avoid overfitting or a loss in predictive power); skip_last ∈ [0, s_max]
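   The constraints in the Key above can be checked with a few lines of Python. This is a hypothetical stand-alone sketch, not MADlib's `extract_keyvalue_params`: it assumes the defaults `R=6, eta=3, skip_last=0` that appear in the diff and rejects any parameter name other than R, eta, and skip_last (case-sensitive).

```python
DEFAULTS = {'R': 6, 'eta': 3, 'skip_last': 0}


def parse_automl_params(automl_params):
    """Parse an 'R=9, eta=3, skip_last=2' style string, validate it, and
    return the params plus the derived s_max."""
    params = dict(DEFAULTS)
    for item in automl_params.split(','):
        item = item.strip()
        if not item:
            continue
        name, sep, value = item.partition('=')
        name = name.strip()
        if not sep or name not in DEFAULTS:
            raise ValueError("invalid param: {0!r}".format(item))
        params[name] = int(value.strip())

    R, eta, skip_last = params['R'], params['eta'], params['skip_last']
    if eta <= 1:
        raise ValueError("eta must be greater than 1")
    if R < eta:
        raise ValueError("R should not be less than eta")

    # s_max = floor(log(R) / log(eta)), computed with integers
    s_max = 0
    while eta ** (s_max + 1) <= R:
        s_max += 1
    if not 0 <= skip_last <= s_max:
        raise ValueError("skip_last must be in [0, {0}]".format(s_max))
    return dict(params, s_max=s_max)


print(parse_automl_params('R=9, eta=3, skip_last=2'))
```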
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [madlib] Advitya17 edited a comment on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
Advitya17 edited a comment on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-688997702


   > @Advitya17 What is the syntax error here please?
   > 
   > ```
   > SELECT madlib.madlib_keras_automl('cifar10_train_packed', 
   >                            'automl_cifar10_output', 
   >                            'model_arch_library', 
   >                            'automl_cifar_10_mst_table',
   >                            ARRAY[1], 
   >                            $${'loss': ['categorical_crossentropy'], 
   >                               'optimizer_params_list': [
   >                                   {'optimizer': ['Adam'],'lr': [0.001, 0.0001, 'log']}
   >                                   {'optimizer': ['Adam'],'lr': [0.001, 0.0001, 'log']}],
   >                              'metrics':['accuracy']}$$, 
   >                            $${'batch_size': [64, 128], 'epochs': [1]}$$,
   >                            'hyperband', 
   >                            'R=9, eta=3, skip_last=2');
   > ```
   > 
   > ```
   > InternalError: (psycopg2.errors.InternalError_) SyntaxError: invalid syntax (<unknown>, line 1) (plpython.c:5038)
   > CONTEXT:  Traceback (most recent call last):
   >   PL/Python function "madlib_keras_automl", line 21, in <module>
   >     schedule_loader = madlib_keras_automl.KerasAutoML(**globals())
   >   PL/Python function "madlib_keras_automl", line 42, in wrapper
   >   PL/Python function "madlib_keras_automl", line 215, in __init__
   >   PL/Python function "madlib_keras_automl", line 308, in find_hyperband_config
   >   PL/Python function "madlib_keras_automl", line 42, in wrapper
   >   PL/Python function "madlib_keras_automl", line 283, in __init__
   >   PL/Python function "madlib_keras_automl", line 48, in literal_eval
   >   PL/Python function "madlib_keras_automl", line 36, in parse
   > PL/Python function "madlib_keras_automl"
   > 
   > [SQL: SELECT madlib.madlib_keras_automl('cifar10_train_packed', 
   >                            'automl_cifar10_output', 
   >                            'model_arch_library', 
   >                            'automl_cifar_10_mst_table',
   >                            ARRAY[1], 
   >                            $${'loss': ['categorical_crossentropy'], 
   >                               'optimizer_params_list': [
   >                                   {'optimizer': ['Adam'],'lr': [0.001, 0.0001, 'log']}
   >                                   {'optimizer': ['Adam'],'lr': [0.001, 0.0001, 'log']}],
   >                              'metrics':['accuracy']}$$, 
   >                            $${'batch_size': [64, 128], 'epochs': [1]}$$,
   >                            'hyperband', 
   >                            'R=9, eta=3, skip_last=2');]
   > (Background on this error at: http://sqlalche.me/e/2j85)
   > ```
   
   - There needs to be a comma between the dictionaries inside `optimizer_params_list`.
   - Every param in `optimizer_params_list` except `optimizer` should use the format `[lower_bound, upper_bound, distribution_type]`, so you should specify `[0.0001, 0.001, 'log']` instead of `[0.001, 0.0001, 'log']`.
   
   Thanks for the question, though. When I ran it with the comma fixed, I got the assertion error about the distribution format for `[0.001, 0.0001, 'log']`, and I noticed that the error message I raise there isn't very intuitive, so I'll reword it a bit.
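   The `literal_eval` frame in the traceback suggests the grid string is parsed as a Python literal, which is why the missing comma surfaces as a `SyntaxError`. Assuming that, the sketch below reproduces both problems with plain `ast.literal_eval` and a made-up `sample_param` helper that enforces `lower_bound < upper_bound` and draws a log-uniform value for `'log'`. It is not MADlib's validator, and accepting `'linear'` is an assumption of the sketch.

```python
import ast
import math
import random

# The grid from the question: no comma between the two dicts, reversed bounds.
broken = """{'loss': ['categorical_crossentropy'],
 'optimizer_params_list': [
     {'optimizer': ['Adam'], 'lr': [0.001, 0.0001, 'log']}
     {'optimizer': ['Adam'], 'lr': [0.001, 0.0001, 'log']}],
 'metrics': ['accuracy']}"""

# Corrected grid: comma added and bounds given as [lower, upper, distribution].
fixed = """{'loss': ['categorical_crossentropy'],
 'optimizer_params_list': [
     {'optimizer': ['Adam'], 'lr': [0.0001, 0.001, 'log']},
     {'optimizer': ['Adam'], 'lr': [0.0001, 0.001, 'log']}],
 'metrics': ['accuracy']}"""


def sample_param(spec, rng):
    """Hypothetical helper: draw one value from [lower, upper, distribution]."""
    lower, upper, dist = spec
    if not lower < upper:
        raise ValueError("expected [lower_bound, upper_bound, distribution_type] "
                         "with lower_bound < upper_bound, got {0}".format(spec))
    if dist == 'log':
        # log-uniform: sample uniformly in log space, then map back
        return math.exp(rng.uniform(math.log(lower), math.log(upper)))
    if dist == 'linear':
        return rng.uniform(lower, upper)
    raise ValueError("unsupported distribution_type: {0!r}".format(dist))


try:
    ast.literal_eval(broken)
except SyntaxError as err:
    print("broken grid: SyntaxError ({0})".format(err.msg))

grid = ast.literal_eval(fixed)
rng = random.Random(42)
for params in grid['optimizer_params_list']:
    print(params['optimizer'][0], sample_param(params['lr'], rng))
```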





[GitHub] [madlib] khannaekta commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
khannaekta commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r478773154



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -17,17 +17,35 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from datetime import datetime
 import plpy
 import math
+from time import time
 
-from utilities.utilities import _assert
+from madlib_keras_validator import MstLoaderInputValidator
+from utilities.utilities import add_postfix, extract_keyvalue_params, _assert, _assert_equal
 from utilities.control import MinWarning
+from madlib_keras_fit_multiple_model import FitMultipleModel
+from madlib_keras_model_selection import MstSearch, ModelSelectionSchema
+from keras_model_arch_table import ModelArchSchema
+from utilities.validate_args import table_exists, drop_tables
+
 
 class AutoMLSchema:
     BRACKET = 's'
     ROUND = 'i'
     CONFIGURATIONS = 'n_i'
     RESOURCES = 'r_i'
+    HYPERBAND = 'hyperband'
+    R = 'R'
+    ETA = 'eta'
+    SKIP_LAST = 'skip_last'
+    LOSS_METRIC = 'training_loss_final'
+    TEMP_MST_TABLE = 'temp_mst_table'
+    TEMP_MST_SUMMARY_TABLE = add_postfix(TEMP_MST_TABLE, '_summary')
+    TEMP_OUTPUT_TABLE = 'temp_output_table'

Review comment:
       unique_string('temp_output_table')
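       The concern is that fixed names such as `temp_output_table` can collide with an existing table or with another session running the same function. MADlib's own `unique_string` is the helper suggested here; the stand-alone sketch below only illustrates the idea with a made-up `unique_table_name` based on `uuid4`, and its naming format is an assumption.

```python
import uuid


def unique_table_name(prefix):
    """Hypothetical helper (not MADlib's unique_string): append a random
    suffix so concurrent runs do not collide on temp table names."""
    name = "__{0}_{1}__".format(prefix, uuid.uuid4().hex)
    # PostgreSQL truncates identifiers longer than 63 bytes by default
    if len(name) > 63:
        raise ValueError("table name too long: {0}".format(name))
    return name


print(unique_table_name('temp_output_table'))
```

       One design note: a class attribute is evaluated once when the module is imported, so generating the name per run (for example in `__init__`) is the safer choice.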

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 or len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max))
+            # total number of resources/iterations (without reuse) per execution of Succesive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: interations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,
+                                              AutoMLSchema.TEMP_MST_TABLE, num_iterations, self.use_gpus,
+                                              self.validation_table, mcf, self.warm_start, self.name, self.description)
+            self.update_model_output_table(model_training)
+            self.update_model_output_info_table(i, model_training,initial_vals)
+        self.end_training_time = self.get_current_timestamp()
+        self.update_model_selection_table()
+        self.generate_model_output_summary_table(model_training)
+        self.remove_temp_tables(model_training)
+
+    def get_current_timestamp(self):
+        """for start and end times for the chosen AutoML algorithm. Showcased in the output summary table"""
+        return datetime.fromtimestamp(time()).strftime('%Y-%m-%d %H:%M:%S')
+
+    def mst_key_ranges_dict(self, initial_vals):
+        """
+        Extracts the ranges of model configs (using mst_keys) belonging to / sampled as part of
+        executing a particular SHA bracket.
+        """
+        d = {}
+        for s_val in sorted(initial_vals.keys(), reverse=True): # going from s_max to 0
+            if s_val == self.s_max:
+                d[s_val] = (1, initial_vals[s_val][0])
+            else:
+                d[s_val] = (d[s_val+1][1]+1, d[s_val+1][1]+initial_vals[s_val][0])
+        return d
+
+    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
+        """
+        Drops and Reconstructs a temp mst table for evaluation along particular diagonals of hyperband.
+        :param i: outer diagonal loop iteration.
+        :param ranges_dict: model config ranges to group by bracket number.
+        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
+        :return:
+        """
+        if i == 0:
+            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
+            lower_bound, upper_bound = ranges_dict[self.s_max]
+            plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
+                         "WHERE mst_key >= {lower_bound} AND mst_key <= {upper_bound}".format(self=self,
+                                                                                              AutoMLSchema=AutoMLSchema,
+                                                                                              lower_bound=lower_bound,
+                                                                                              upper_bound=upper_bound,))
+            return
+        # dropping and repopulating temp_mst_table
+        drop_tables([AutoMLSchema.TEMP_MST_TABLE])
+
+        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
+        create_query = """
+                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
+                            {mst_key} INTEGER,
+                            {model_id} INTEGER,
+                            {compile_params} VARCHAR,
+                            {fit_params} VARCHAR,
+                            unique ({model_id}, {compile_params}, {fit_params})
+                        );
+                       """.format(AutoMLSchema=AutoMLSchema,
+                                  mst_key=ModelSelectionSchema.MST_KEY,
+                                  model_id=ModelSelectionSchema.MODEL_ID,
+                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
+        with MinWarning('warning'):
+            plpy.execute(create_query)
+
+        query = ""
+        new_configs = True
+        for s_val in configs_prune_lookup:
+            lower_bound, upper_bound = ranges_dict[s_val]
+            if new_configs:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_selection_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                lower_bound=lower_bound, upper_bound=upper_bound)
+                new_configs = False
+            else:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_info_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
+                         "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                    lower_bound=lower_bound, upper_bound=upper_bound,
+                                                                    configs_prune_lookup_val=configs_prune_lookup[s_val])
+        plpy.execute(query)
+
+    def update_model_output_table(self, model_training):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output table.
+        :param model_training: Fit Multiple function call object.
+        """
+        # updates model weights for any previously trained configs
+        plpy.execute("UPDATE {self.model_output_table} a SET model_weights="
+                     "t.model_weights FROM {model_training.original_model_output_table} t " \
+                     "WHERE a.mst_key=t.mst_key".format(self=self, model_training=model_training))
+
+        # inserts any newly trained configs
+        plpy.execute("INSERT INTO {self.model_output_table} SELECT * FROM {model_training.original_model_output_table} " \
+                     "WHERE {model_training.original_model_output_table}.mst_key NOT IN "
+                     "(SELECT mst_key FROM {self.model_output_table})".format(self=self,
+                                                                              model_training=model_training))
+
+    def update_model_output_info_table(self, i, model_training, initial_vals):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output info table.
+        :param i: outer diagonal loop iteration.
+        :param model_training: Fit Multiple function call object.
+        :param initial_vals: Dictionary of initial configurations and resources as part of the initial hyperband
+        schedule.
+        """
+        # normalizing factor for metrics_iters due to warm start
+        epochs_factor = sum([n[1] for n in initial_vals.values()][::-1][:i])
+        iters = plpy.execute("SELECT {AutoMLSchema.METRICS_ITERS} " \
+                             "FROM {model_training.model_summary_table}".format(AutoMLSchema=AutoMLSchema,
+                                                                                model_training=model_training))
+        metrics_iters_val = [epochs_factor+mi for mi in iters[0]['metrics_iters']]
+
+        # casting same metrics_iters values for the fit_multiple run with the chosen configs
+        # in order to update overall info table later
+        plpy.execute("ALTER TABLE {model_training.model_info_table} " \
+                     "ADD COLUMN {AutoMLSchema.METRICS_ITERS} INTEGER[]".format(model_training=model_training,
+                                                                                AutoMLSchema=AutoMLSchema))
+        plpy.execute("UPDATE {model_training.model_info_table} SET {AutoMLSchema.METRICS_ITERS} = " \
+                     "ARRAY{metrics_iters_val}::INTEGER[] FROM {model_training.model_summary_table}".format(
+            model_training=model_training, metrics_iters_val=metrics_iters_val, AutoMLSchema=AutoMLSchema))
+        validation_update_q = "validation_metrics_final=t.validation_metrics_final, " \
+                                     "validation_loss_final=t.validation_loss_final, " \
+                                     "validation_metrics=a.validation_metrics || t.validation_metrics, " \
+                                     "validation_loss=a.validation_loss || t.validation_loss, " \
+            if self.validation_table else ""
+
+        # updates train/val info for any previously trained configs
+        plpy.execute("UPDATE {self.model_info_table} a SET " \
+                     "metrics_elapsed_time=a.metrics_elapsed_time || t.metrics_elapsed_time, " \
+                     "training_metrics_final=t.training_metrics_final, " \
+                     "training_loss_final=t.training_loss_final, " \
+                     "training_metrics= a.training_metrics || t.training_metrics, " \
+                     "training_loss= a.training_loss || t.training_loss, ".format(self=self) + validation_update_q +
+                     "metrics_iters=a.metrics_iters || t.metrics_iters " \
+                     "FROM {model_training.model_info_table} t " \
+                     "WHERE a.mst_key=t.mst_key".format(model_training=model_training))
+
+        # inserts info about metrics and validation for newly trained model configs
+        plpy.execute("INSERT INTO {self.model_info_table} SELECT * FROM {model_training.model_info_table} " \
+                     "WHERE {model_training.model_info_table}.mst_key NOT IN "
+                     "(SELECT mst_key FROM {self.model_info_table})".format(self=self,
+                                                                            model_training=model_training))
+
+        # removes metrics_iters column from temp mst info table, for further use in warm started model training
+        plpy.execute("ALTER TABLE {model_training.model_info_table} " \
+                     "DROP COLUMN {AutoMLSchema.METRICS_ITERS}".format(model_training=model_training,
+                                                                       AutoMLSchema=AutoMLSchema))
+
+    def update_model_selection_table(self):
+        """
+        Drops and re-createst the mst table to only include the best performing model configuration.

Review comment:
       typo in re-createst

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -17,17 +17,35 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from datetime import datetime
 import plpy
 import math
+from time import time
 
-from utilities.utilities import _assert
+from madlib_keras_validator import MstLoaderInputValidator
+from utilities.utilities import add_postfix, extract_keyvalue_params, _assert, _assert_equal
 from utilities.control import MinWarning
+from madlib_keras_fit_multiple_model import FitMultipleModel
+from madlib_keras_model_selection import MstSearch, ModelSelectionSchema
+from keras_model_arch_table import ModelArchSchema
+from utilities.validate_args import table_exists, drop_tables
+
 
 class AutoMLSchema:
     BRACKET = 's'
     ROUND = 'i'
     CONFIGURATIONS = 'n_i'
     RESOURCES = 'r_i'
+    HYPERBAND = 'hyperband'
+    R = 'R'
+    ETA = 'eta'
+    SKIP_LAST = 'skip_last'
+    LOSS_METRIC = 'training_loss_final'
+    TEMP_MST_TABLE = 'temp_mst_table'

Review comment:
       we should use `unique_string('temp_mst_table')` instead to create a unique name.

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
##########
@@ -21,6 +21,347 @@
 
 m4_include(`SQLCommon.m4')
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+
+--------------------------- MADLIB KERAS AUTOML HYPERBAND TEST CASES ---------------------------
+
+DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
+SELECT training_preprocessor_dl('iris_train',         -- Source table
+                                       'iris_train_packed',  -- Output table
+                                       'class_text',         -- Dependent variable
+                                       'attributes'          -- Independent variable
+           );
+
+DROP TABLE IF EXISTS iris_test_packed, iris_test_packed_summary;
+SELECT validation_preprocessor_dl('iris_test',          -- Source table
+                                         'iris_test_packed',   -- Output table

Review comment:
       indentation

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
##########
@@ -21,6 +21,347 @@
 
 m4_include(`SQLCommon.m4')
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+
+--------------------------- MADLIB KERAS AUTOML HYPERBAND TEST CASES ---------------------------
+
+DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
+SELECT training_preprocessor_dl('iris_train',         -- Source table
+                                       'iris_train_packed',  -- Output table

Review comment:
       indentation is off.

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
##########
@@ -21,6 +21,347 @@
 
+DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
+SELECT training_preprocessor_dl('iris_train',         -- Source table
+                                       'iris_train_packed',  -- Output table

Review comment:
       You can remove these calls to training_preprocessor_dl/validation_preprocessor_dl, since they are already made in the setup file `madlib_keras_iris.setup.sql_in`, and use iris_data_packed as the input table.
   

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 or len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max))
+            # total number of resources/iterations (without reuse) per execution of Successive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: iterations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,
+                                              AutoMLSchema.TEMP_MST_TABLE, num_iterations, self.use_gpus,
+                                              self.validation_table, mcf, self.warm_start, self.name, self.description)
+            self.update_model_output_table(model_training)
+            self.update_model_output_info_table(i, model_training,initial_vals)
+        self.end_training_time = self.get_current_timestamp()
+        self.update_model_selection_table()
+        self.generate_model_output_summary_table(model_training)
+        self.remove_temp_tables(model_training)
+
+    def get_current_timestamp(self):
+        """for start and end times for the chosen AutoML algorithm. Showcased in the output summary table"""
+        return datetime.fromtimestamp(time()).strftime('%Y-%m-%d %H:%M:%S')
+
+    def mst_key_ranges_dict(self, initial_vals):
+        """
+        Extracts the ranges of model configs (using mst_keys) belonging to / sampled as part of
+        executing a particular SHA bracket.
+        """
+        d = {}
+        for s_val in sorted(initial_vals.keys(), reverse=True): # going from s_max to 0
+            if s_val == self.s_max:
+                d[s_val] = (1, initial_vals[s_val][0])
+            else:
+                d[s_val] = (d[s_val+1][1]+1, d[s_val+1][1]+initial_vals[s_val][0])
+        return d
+
+    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
+        """
+        Drops and Reconstructs a temp mst table for evaluation along particular diagonals of hyperband.
+        :param i: outer diagonal loop iteration.
+        :param ranges_dict: model config ranges to group by bracket number.
+        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
+        :return:
+        """
+        if i == 0:
+            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
+            lower_bound, upper_bound = ranges_dict[self.s_max]
+            plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
+                         "WHERE mst_key >= {lower_bound} AND mst_key <= {upper_bound}".format(self=self,
+                                                                                              AutoMLSchema=AutoMLSchema,
+                                                                                              lower_bound=lower_bound,
+                                                                                              upper_bound=upper_bound,))
+            return
+        # dropping and repopulating temp_mst_table
+        drop_tables([AutoMLSchema.TEMP_MST_TABLE])
+
+        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
+        create_query = """
+                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
+                            {mst_key} INTEGER,
+                            {model_id} INTEGER,
+                            {compile_params} VARCHAR,
+                            {fit_params} VARCHAR,
+                            unique ({model_id}, {compile_params}, {fit_params})
+                        );
+                       """.format(AutoMLSchema=AutoMLSchema,
+                                  mst_key=ModelSelectionSchema.MST_KEY,
+                                  model_id=ModelSelectionSchema.MODEL_ID,
+                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
+        with MinWarning('warning'):
+            plpy.execute(create_query)
+
+        query = ""
+        new_configs = True
+        for s_val in configs_prune_lookup:
+            lower_bound, upper_bound = ranges_dict[s_val]
+            if new_configs:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_selection_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                lower_bound=lower_bound, upper_bound=upper_bound)
+                new_configs = False
+            else:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_info_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
+                         "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                    lower_bound=lower_bound, upper_bound=upper_bound,
+                                                                    configs_prune_lookup_val=configs_prune_lookup[s_val])
+        plpy.execute(query)
+
+    def update_model_output_table(self, model_training):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output table.
+        :param model_training: Fit Multiple function call object.
+        """
+        # updates model weights for any previously trained configs
+        plpy.execute("UPDATE {self.model_output_table} a SET model_weights="

Review comment:
       UPDATE is essentially DELETE/INSERT. In the past we saw that updating the table just doubles its size, since the update runs inside a UDF (a single transaction). If this table is going to hold at least as many mst_keys as there are segments, and update that many keys in every diagonal it trains, I think we should TRUNCATE the table and INSERT into it instead of calling UPDATE.
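The sketch below shows one way the TRUNCATE/INSERT suggestion could look. `create_staging_query` and `merge_weights_queries` are hypothetical helpers that return SQL strings for the driver to run in order, for example via `plpy.execute`. The sketch makes three assumptions:

- The output table has the `mst_key`, `model_weights` and `model_arch` columns created in `create_model_output_table`.
- The per-diagonal training output table has those same three columns.
- A single staging table is reused across diagonals instead of being recreated each time.

The FULL OUTER JOIN keeps configs that were not retrained in this diagonal. It also adds configs that appear for the first time.

```python
COLS = "mst_key, model_weights, model_arch"  # assumed column names


def create_staging_query(output_table, staging_table):
    """Run once per AutoML call: an empty table shaped like the output table."""
    return "CREATE TABLE {stg} (LIKE {out})".format(stg=staging_table, out=output_table)


def merge_weights_queries(output_table, new_weights_table, staging_table):
    """Hypothetical helper: fold freshly trained weights into `output_table`
    with TRUNCATE + INSERT instead of UPDATE. Run once per diagonal."""
    statements = [
        # 1. Clear the reusable staging table.
        "TRUNCATE {stg}",
        # 2. Stage the merged rows: new weights win, untouched configs are kept.
        "INSERT INTO {stg} (" + COLS + ") "
        "SELECT COALESCE(n.mst_key, o.mst_key), "
        "COALESCE(n.model_weights, o.model_weights), "
        "COALESCE(n.model_arch, o.model_arch) "
        "FROM {out} o FULL OUTER JOIN {new} n ON o.mst_key = n.mst_key",
        # 3. Empty the output table without leaving dead row versions behind.
        "TRUNCATE {out}",
        # 4. Reload it from the staged rows.
        "INSERT INTO {out} (" + COLS + ") SELECT " + COLS + " FROM {stg}",
    ]
    return [q.format(out=output_table, new=new_weights_table, stg=staging_table)
            for q in statements]
```

In the driver, `create_staging_query` would run once before the diagonal loop. The four merge statements would then run once per diagonal in place of the UPDATE, and the staging table could be dropped alongside the other temp tables in `remove_temp_tables`.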

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,

Review comment:
       `FitMultipleModel` goes through additional validation of the mst_table, output_table, etc. for every diagonal we explore. Can we initialize the object (model_training) once and just call `model_training.run_training()` for each diagonal?
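The sketch below illustrates the suggested pattern under stated assumptions.

`hyperband_schedule` reproduces the bracket arithmetic of `find_hyperband_config`: the initial `n` and `r` per bracket, then the configs kept per bracket and the iterations used on each diagonal. It differs in one respect: it computes `s_max` with an integer loop instead of `math.floor(math.log(R, eta))`. Floating-point `math.log` can land just below an exact integer; for example, `math.log(1000, 10)` evaluates to `2.9999999999999996`.

`HypotheticalFitMultiple` and `run_diagonals` are hypothetical stand-ins, not the real `FitMultipleModel` API. The sketch assumes that validation can move into the constructor, and that `run_training` accepts the per-diagonal iteration count and re-reads the rebuilt temp mst table on each call.

```python
import math


def hyperband_schedule(R, eta, skip_last=0):
    """Return [(num_iterations, {bracket s: configs to train})], one per diagonal."""
    if eta <= 1 or R < eta:
        raise ValueError("need eta > 1 and R >= eta")
    # Integer floor(log_eta(R)): the largest s_max with eta**s_max <= R.
    s_max = 0
    while eta ** (s_max + 1) <= R:
        s_max += 1
    if not 0 <= skip_last <= s_max:
        raise ValueError("skip_last must be between 0 and s_max")
    initial_vals = {}
    for s in range(s_max, -1, -1):
        # Same as int(B / R / (s + 1)) in the diff, since B / R == s_max + 1.
        n = int(math.ceil(((s_max + 1) // (s + 1)) * math.pow(eta, s)))
        r = R * math.pow(eta, -s)
        initial_vals[s] = (n, int(round(r)))
    schedule = []
    for i in range(s_max + 1 - skip_last):
        # Diagonal i advances brackets s_max .. s_max - i; bracket s is at
        # round i - (s_max - s) and keeps n * eta**-(round) configs.
        configs = {}
        for s in range(s_max, s_max - i - 1, -1):
            n_i = initial_vals[s][0] * math.pow(eta, -i + s_max - s)
            configs[s] = int(round(n_i))
        schedule.append((initial_vals[s_max - i][1], configs))
    return schedule


class HypotheticalFitMultiple(object):
    """Stand-in for FitMultipleModel: validate once, train once per diagonal."""

    def __init__(self, source_table, mst_table):
        self.source_table = source_table
        self.mst_table = mst_table
        self.validation_calls = 0
        self._validate()  # table/column checks run here, exactly once

    def _validate(self):
        self.validation_calls += 1
        if not self.source_table or not self.mst_table:
            raise ValueError("source_table and mst_table are required")

    def run_training(self, num_iterations, warm_start):
        # Assumed to re-read the (rebuilt) mst table on every call.
        return {"iterations": num_iterations, "warm_start": warm_start}


def run_diagonals(trainer, schedule):
    results = []
    for i, (num_iterations, configs) in enumerate(schedule):
        # The driver would rebuild the temp mst table from `configs` here.
        results.append(trainer.run_training(num_iterations, warm_start=(i != 0)))
    return results


# Default automl_params (R=6, eta=3): two diagonals, validation runs once.
trainer = HypotheticalFitMultiple("iris_data_packed", "temp_mst_table")
results = run_diagonals(trainer, hyperband_schedule(R=6, eta=3))
```

With the defaults, the sketch returns `[(2, {1: 3}), (6, {1: 1, 0: 2})]`.

The expression `int(B / R / (s + 1))` floors the ratio before scaling by `eta**s`. The Hyperband paper instead takes the ceiling of the unfloored product. The two agree whenever `s + 1` divides `s_max + 1`, as it does for the defaults.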

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 and len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max+1))
+            # total number of resources/iterations (without reuse) per execution of Successive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: iterations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,
+                                              AutoMLSchema.TEMP_MST_TABLE, num_iterations, self.use_gpus,
+                                              self.validation_table, mcf, self.warm_start, self.name, self.description)
+            self.update_model_output_table(model_training)
+            self.update_model_output_info_table(i, model_training, initial_vals)
+        self.end_training_time = self.get_current_timestamp()
+        self.update_model_selection_table()
+        self.generate_model_output_summary_table(model_training)
+        self.remove_temp_tables(model_training)
+
+    def get_current_timestamp(self):
+        """for start and end times for the chosen AutoML algorithm. Showcased in the output summary table"""
+        return datetime.fromtimestamp(time()).strftime('%Y-%m-%d %H:%M:%S')
+
+    def mst_key_ranges_dict(self, initial_vals):
+        """
+        Extracts the ranges of model configs (using mst_keys) belonging to / sampled as part of
+        executing a particular SHA bracket.
+        """
+        d = {}
+        for s_val in sorted(initial_vals.keys(), reverse=True): # going from s_max to 0
+            if s_val == self.s_max:
+                d[s_val] = (1, initial_vals[s_val][0])
+            else:
+                d[s_val] = (d[s_val+1][1]+1, d[s_val+1][1]+initial_vals[s_val][0])
+        return d
+
+    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
+        """
+        Drops and reconstructs a temp mst table for evaluation along particular diagonals of hyperband.
+        :param i: outer diagonal loop iteration.
+        :param ranges_dict: model config ranges to group by bracket number.
+        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
+        :return:
+        """
+        if i == 0:
+            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
+            lower_bound, upper_bound = ranges_dict[self.s_max]
+            plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
+                         "WHERE mst_key >= {lower_bound} AND mst_key <= {upper_bound}".format(self=self,
+                                                                                              AutoMLSchema=AutoMLSchema,
+                                                                                              lower_bound=lower_bound,
+                                                                                              upper_bound=upper_bound,))
+            return
+        # dropping and repopulating temp_mst_table
+        drop_tables([AutoMLSchema.TEMP_MST_TABLE])
+
+        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
+        create_query = """
+                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
+                            {mst_key} INTEGER,
+                            {model_id} INTEGER,
+                            {compile_params} VARCHAR,
+                            {fit_params} VARCHAR,
+                            unique ({model_id}, {compile_params}, {fit_params})
+                        );
+                       """.format(AutoMLSchema=AutoMLSchema,
+                                  mst_key=ModelSelectionSchema.MST_KEY,
+                                  model_id=ModelSelectionSchema.MODEL_ID,
+                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
+        with MinWarning('warning'):
+            plpy.execute(create_query)
+
+        query = ""
+        new_configs = True
+        for s_val in configs_prune_lookup:
+            lower_bound, upper_bound = ranges_dict[s_val]
+            if new_configs:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_selection_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                lower_bound=lower_bound, upper_bound=upper_bound)
+                new_configs = False
+            else:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_info_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
+                         "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                    lower_bound=lower_bound, upper_bound=upper_bound,
+                                                                    configs_prune_lookup_val=configs_prune_lookup[s_val])
+        plpy.execute(query)
+
+    def update_model_output_table(self, model_training):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output table.
+        :param model_training: Fit Multiple function call object.
+        """
+        # updates model weights for any previously trained configs
+        plpy.execute("UPDATE {self.model_output_table} a SET model_weights="
+                     "t.model_weights FROM {model_training.original_model_output_table} t " \
+                     "WHERE a.mst_key=t.mst_key".format(self=self, model_training=model_training))
+
+        # inserts any newly trained configs
+        plpy.execute("INSERT INTO {self.model_output_table} SELECT * FROM {model_training.original_model_output_table} " \
+                     "WHERE {model_training.original_model_output_table}.mst_key NOT IN "
+                     "(SELECT mst_key FROM {self.model_output_table})".format(self=self,
+                                                                              model_training=model_training))
+
+    def update_model_output_info_table(self, i, model_training, initial_vals):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output info table.
+        :param i: outer diagonal loop iteration.
+        :param model_training: Fit Multiple function call object.
+        :param initial_vals: Dictionary of initial configurations and resources as part of the initial hyperband
+        schedule.
+        """
+        # normalizing factor for metrics_iters due to warm start
+        epochs_factor = sum([n[1] for n in initial_vals.values()][::-1][:i])
+        iters = plpy.execute("SELECT {AutoMLSchema.METRICS_ITERS} " \
+                             "FROM {model_training.model_summary_table}".format(AutoMLSchema=AutoMLSchema,
+                                                                                model_training=model_training))
+        metrics_iters_val = [epochs_factor+mi for mi in iters[0]['metrics_iters']]
+
+        # casting same metrics_iters values for the fit_multiple run with the chosen configs
+        # in order to update overall info table later
+        plpy.execute("ALTER TABLE {model_training.model_info_table} " \

Review comment:
       Why do we ALTER TABLE to add/drop a column on every diagonal we explore? ALTER TABLE rewrites the entire table.
   We should avoid it, since we only need to add the metrics_iters values to `self.model_info_table`.
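The following is a minimal sketch of the reviewer's suggestion. It assumes the Fit Multiple info table shares every column of `self.model_info_table` except `metrics_iters`. The shifted `metrics_iters` is computed once per diagonal and spliced into an UPDATE/INSERT pair as an array literal, mirroring the pattern in `update_model_output_table`, so no ALTER TABLE is needed. The helper names and the `columns` argument are hypothetical.

The offset is summed over brackets `s_max, s_max-1, ..., s_max-i+1` in an explicit order. The original expression `[n[1] for n in initial_vals.values()][::-1][:i]` depends on dict iteration order, which differs between CPython 2 and Python 3.7+.

```python
def shifted_metrics_iters(initial_vals, s_max, i, metrics_iters):
    """Offset metrics_iters by the iterations already trained on diagonals 0..i-1.

    Diagonal d trains for initial_vals[s_max - d][1] iterations, so the offset is
    summed over brackets s_max, s_max-1, ..., s_max-i+1 in an explicit order.
    """
    epochs_factor = sum(initial_vals[s][1] for s in range(s_max, s_max - i, -1))
    return [epochs_factor + mi for mi in metrics_iters]


def build_info_upsert(target_table, source_table, columns, metrics_iters, key="mst_key"):
    """Return (update_sql, insert_sql) that copy rows from source_table into
    target_table and set metrics_iters from a literal, without ALTER TABLE.

    `columns` (hypothetical) lists the columns shared by both tables; it must
    include `key` and must not include metrics_iters.
    """
    arr = "ARRAY[{0}]::INTEGER[]".format(", ".join(str(int(v)) for v in metrics_iters))
    sets = ["{0} = t.{0}".format(c) for c in columns if c != key]
    sets.append("metrics_iters = " + arr)
    col_list = ", ".join(columns)
    # Refresh rows for configs that were already trained on an earlier diagonal.
    update_sql = ("UPDATE {tgt} a SET {sets} "
                  "FROM {src} t WHERE a.{key} = t.{key}").format(
                      tgt=target_table, sets=", ".join(sets), src=source_table, key=key)
    # Append rows for configs trained for the first time on this diagonal.
    insert_sql = ("INSERT INTO {tgt} ({cols}, metrics_iters) "
                  "SELECT {cols}, {arr} FROM {src} t "
                  "WHERE t.{key} NOT IN (SELECT {key} FROM {tgt})").format(
                      tgt=target_table, cols=col_list, arr=arr, src=source_table, key=key)
    return update_sql, insert_sql
```

Take the second diagonal of an R=9, eta=3 run (i=1). Only 1 iteration was trained before it, so per-diagonal `metrics_iters` of [1, 2, 3] becomes [2, 3, 4].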

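The per-diagonal selection in `reconstruct_temp_mst_table` can be expressed as a pure-Python sketch (helper names hypothetical):

- `mst_key_ranges` mirrors `mst_key_ranges_dict`.
- `select_configs` keeps every freshly sampled key of the bracket that starts on diagonal `i`.
- For each older bracket, `select_configs` keeps the `n_i` lowest-loss keys, like the `ORDER BY ... LIMIT` query.

Unlike the SQL version, the sketch identifies the starting bracket as `s == s_max - i`. The SQL version relies on the `new_configs` flag and the iteration order of `configs_prune_lookup`, which depends on dict ordering. As the SQL appears to do, `losses` holds each trained key's most recent loss. Keys pruned on an earlier diagonal are therefore still ranked by their last recorded value.

```python
import math


def mst_key_ranges(initial_vals, s_max):
    """Map each bracket s to its contiguous (lower, upper) mst_key range.

    Keys are numbered from 1, starting with bracket s_max, as in mst_key_ranges_dict.
    """
    ranges = {}
    next_key = 1
    for s in range(s_max, -1, -1):
        n = initial_vals[s][0]
        ranges[s] = (next_key, next_key + n - 1)
        next_key += n
    return ranges


def select_configs(i, s_max, eta, initial_vals, ranges, losses):
    """Return the sorted mst_keys to train on diagonal i.

    `losses` maps each already-trained mst_key to its most recent loss.
    """
    selected = []
    for s in range(s_max, s_max - i - 1, -1):
        lower, upper = ranges[s]
        keys = range(lower, upper + 1)
        if s == s_max - i:
            # Bracket s starts on this diagonal: train all of its sampled configs.
            selected.extend(keys)
        else:
            # Older bracket: keep the n_i configs with the lowest loss.
            n_i = int(round(initial_vals[s][0] * math.pow(eta, -i + s_max - s)))
            trained = [k for k in keys if k in losses]
            selected.extend(sorted(trained, key=lambda k: losses[k])[:n_i])
    return sorted(selected)
```

Feed the sketch R=9, eta=3 and the validation losses printed after iteration 1 in the iris run later in this thread. For diagonal 1 it picks keys 2, 6 and 9 from bracket 2, plus the new keys 10, 11 and 12. These are the keys trained in that log.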



----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [madlib] khannaekta commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
khannaekta commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r485844264



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
##########
@@ -21,6 +21,347 @@
 
 m4_include(`SQLCommon.m4')
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+
+--------------------------- MADLIB KERAS AUTOML HYPERBAND TEST CASES ---------------------------
+
+-- DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;

Review comment:
       We can delete this commented-out SQL.





[GitHub] [madlib] Advitya17 commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
Advitya17 commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r479873861



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 and len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max+1))
+            # total number of resources/iterations (without reuse) per execution of Successive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: iterations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,

Review comment:
       That may not be possible, as this function call takes in a different (reconstructed) mst table for each diagonal. 
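The reply above can be illustrated with a pure-Python sketch of the loop bounds in `find_hyperband_config`; the function name and return shape are hypothetical. Each diagonal mixes a different set of (bracket, round) pairs and uses a different iteration count. That is why each `FitMultipleModel` call needs its own mst table.

```python
import math


def diagonal_plan(R, eta, skip_last=0):
    """Return [(num_iterations, [(bracket, round, n_configs), ...]), ...], one entry per diagonal."""
    # s_max = floor(log_eta(R)), computed with integers.
    s_max = 0
    while eta ** (s_max + 1) <= R:
        s_max += 1
    B = (s_max + 1) * R
    # Initial (n, r) per bracket, following find_hyperband_config.
    init = {}
    for s in range(s_max, -1, -1):
        n = int(math.ceil(((B // R) // (s + 1)) * math.pow(eta, s)))
        init[s] = (n, int(round(R * math.pow(eta, -s))))
    plan = []
    for i in range(s_max + 1 - skip_last):
        work = []
        # On diagonal i, bracket s is in round s - s_max + i.
        for s in range(s_max, s_max - i - 1, -1):
            n_i = int(round(init[s][0] * math.pow(eta, -i + s_max - s)))
            work.append((s, s - s_max + i, n_i))
        # All configs on diagonal i train for the r of bracket s_max - i.
        plan.append((init[s_max - i][1], work))
    return plan
```

For R=9, eta=3 the plan has three diagonals with 1, 3 and 9 iterations. These match the three `Diagonally evaluating ...` INFO lines in the iris run later in this thread.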

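In the diff above, `validate_and_define_inputs` computes `s_max` as `int(math.floor(math.log(self.R, self.eta)))`. When `R` is an exact power of `eta`, floating-point `log` can return a value just below the true integer. For instance, `math.log(1000, 10)` evaluates to about 2.9999999999999996 in CPython on common platforms, and `floor` then under-counts `s_max` by one. A minimal integer-only alternative follows; the function name is hypothetical.

```python
def hyperband_s_max(R, eta):
    """Largest integer s such that eta**s <= R, using exact integer arithmetic."""
    if eta <= 1 or R < eta:
        raise ValueError("eta must be greater than 1 and R must not be less than eta")
    s = 0
    power = eta  # invariant: power == eta ** (s + 1)
    while power <= R:
        s += 1
        power *= eta
    return s
```

With an exact `s_max`, the `skip_last` check `skip_last < s_max + 1` is equivalent to `0 <= skip_last <= s_max`.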




[GitHub] [madlib] fmcquillan99 commented on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
fmcquillan99 commented on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-690776040


   iris dataset
   
   (1)
   schedule:
   ```
    s | i | n_i | r_i 
   ---+---+-----+-----
    2 | 0 |   9 |   1
    2 | 1 |   3 |   3
    2 | 2 |   1 |   9
    1 | 0 |   3 |   3
    1 | 1 |   1 |   9
    0 | 0 |   3 |   9
   (6 rows)
   ```
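The schedule above can be reproduced with the standard Hyperband bracket formulas, using the same initial (n, r) computation as `find_hyperband_config`. Round i of bracket s keeps n * eta**(-i) configs and trains them for r * eta**i iterations.

This is a sketch, not the `HyperbandSchedule` implementation, which is not shown here. In particular, it rounds n_i with `round()` as `configs_prune_lookup` does, and the real class may round differently. The function name is hypothetical.

```python
import math


def hyperband_schedule(R, eta):
    """Return (s, i, n_i, r_i) rows, bracket s_max first, as in the table above."""
    s_max = 0
    while eta ** (s_max + 1) <= R:
        s_max += 1
    B = (s_max + 1) * R
    rows = []
    for s in range(s_max, -1, -1):
        n = int(math.ceil(((B // R) // (s + 1)) * math.pow(eta, s)))  # initial configs
        r = R * math.pow(eta, -s)  # initial iterations per config
        for i in range(s + 1):
            n_i = int(round(n * math.pow(eta, -i)))  # configs kept in round i
            r_i = int(round(r * math.pow(eta, i)))   # iterations in round i
            rows.append((s, i, n_i, r_i))
    return rows
```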
   train:
   ```
   DROP TABLE IF EXISTS automl_output, automl_output_info, automl_output_summary, automl_mst_table, automl_mst_table_summary;
   
   SELECT madlib.madlib_keras_automl('iris_train_packed', 
                                     'automl_output', 
                                     'model_arch_library_iris', 
                                     'automl_mst_table',
                                     ARRAY[1], 
                                     $${'loss': ['categorical_crossentropy'], 'optimizer_params_list': [ {'optimizer': ['Adam'],'lr': [0.00999, 0.01001, 'log']}], 'metrics':['accuracy']}$$, 
                                     $${'batch_size': [8], 'epochs': [1]}$$,
                                     'hyperband', 
                                     'R=9, eta=3, skip_last=0',
                                     42,                    -- random state
                                     NULL,                  -- object table
                                     FALSE,                 -- use GPUs
                                     'iris_test_packed',    -- validation table
                                     1,                     -- metrics compute freq
                                     'a name is a name',    -- name
                                     'a descr is a descr'); -- descr
   ```
   produces:
   ```
   INFO:  *** Diagonally evaluating 9 configs under bracket=2 & round=0 with 1 iterations ***
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 1: 8.45693087578 sec
   DETAIL:  
   	Training set after iteration 1:
   	mst_key=2: metric=0.933333337307, loss=0.538550555706
   	mst_key=8: metric=0.316666662693, loss=1.03528738022
   	mst_key=3: metric=0.675000011921, loss=0.833593308926
   	mst_key=6: metric=0.658333361149, loss=0.825735092163
   	mst_key=1: metric=0.808333337307, loss=0.872115910053
   	mst_key=7: metric=0.625, loss=0.974291205406
   	mst_key=9: metric=0.658333361149, loss=0.830321311951
   	mst_key=4: metric=0.658333361149, loss=0.885693848133
   	mst_key=5: metric=0.341666668653, loss=1.09756851196
   	Validation set after iteration 1:
   	mst_key=2: metric=0.966666638851, loss=0.534965932369
   	mst_key=8: metric=0.366666674614, loss=1.03363573551
   	mst_key=3: metric=0.633333325386, loss=0.872259676456
   	mst_key=6: metric=0.699999988079, loss=0.799027740955
   	mst_key=1: metric=0.833333313465, loss=0.866640985012
   	mst_key=7: metric=0.5, loss=0.974297106266
   	mst_key=9: metric=0.699999988079, loss=0.829889714718
   	mst_key=4: metric=0.699999988079, loss=0.885949194431
   	mst_key=5: metric=0.300000011921, loss=1.10424804688
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  *** Diagonally evaluating 3 configs under bracket=2 & round=1, 3 configs under bracket=1 & round=0 with 3 iterations ***
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 1: 4.89221119881 sec
   DETAIL:  
   	Training set after iteration 1:
   	mst_key=9: metric=0.658333361149, loss=0.530654907227
   	mst_key=6: metric=0.774999976158, loss=0.531277298927
   	mst_key=11: metric=0.691666662693, loss=0.861804306507
   	mst_key=2: metric=0.716666638851, loss=0.426045030355
   	mst_key=12: metric=0.608333349228, loss=0.984588444233
   	mst_key=10: metric=0.324999988079, loss=1.17993462086
   	Validation set after iteration 1:
   	mst_key=9: metric=0.699999988079, loss=0.514933288097
   	mst_key=6: metric=0.800000011921, loss=0.519642591476
   	mst_key=11: metric=0.666666686535, loss=0.861222624779
   	mst_key=2: metric=0.733333349228, loss=0.385236173868
   	mst_key=12: metric=0.699999988079, loss=0.961848556995
   	mst_key=10: metric=0.366666674614, loss=1.16651546955
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 2: 5.0256319046 sec
   DETAIL:  
   	Training set after iteration 2:
   	mst_key=9: metric=0.966666638851, loss=0.432032138109
   	mst_key=6: metric=0.908333361149, loss=0.361528068781
   	mst_key=11: metric=0.858333349228, loss=0.686981141567
   	mst_key=2: metric=0.758333325386, loss=0.373835682869
   	mst_key=12: metric=0.600000023842, loss=0.89406478405
   	mst_key=10: metric=0.341666668653, loss=1.13211238384
   	Validation set after iteration 2:
   	mst_key=9: metric=0.933333337307, loss=0.442799389362
   	mst_key=6: metric=0.933333337307, loss=0.350269824266
   	mst_key=11: metric=0.833333313465, loss=0.694237530231
   	mst_key=2: metric=0.899999976158, loss=0.325422286987
   	mst_key=12: metric=0.699999988079, loss=0.868638038635
   	mst_key=10: metric=0.300000011921, loss=1.14590144157
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 3: 4.99592685699 sec
   DETAIL:  
   	Training set after iteration 3:
   	mst_key=9: metric=0.699999988079, loss=0.403143405914
   	mst_key=6: metric=0.949999988079, loss=0.292406588793
   	mst_key=11: metric=0.949999988079, loss=0.614805757999
   	mst_key=2: metric=0.658333361149, loss=0.662374198437
   	mst_key=12: metric=0.47499999404, loss=0.713567197323
   	mst_key=10: metric=0.40000000596, loss=1.08757543564
   	Validation set after iteration 3:
   	mst_key=9: metric=0.733333349228, loss=0.365277469158
   	mst_key=6: metric=0.899999976158, loss=0.303717941046
   	mst_key=11: metric=0.966666638851, loss=0.615865588188
   	mst_key=2: metric=0.633333325386, loss=0.758127868176
   	mst_key=12: metric=0.566666662693, loss=0.689917981625
   	mst_key=10: metric=0.300000011921, loss=1.09673452377
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  *** Diagonally evaluating 1 configs under bracket=2 & round=2, 1 configs under bracket=1 & round=1, 3 configs under bracket=0 & round=0 with 9 iterations ***
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 1: 4.01545906067 sec
   DETAIL:  
   	Training set after iteration 1:
   	mst_key=11: metric=0.949999988079, loss=0.535947144032
   	mst_key=13: metric=0.658333361149, loss=0.576903045177
   	mst_key=6: metric=0.725000023842, loss=0.413865655661
   	mst_key=15: metric=0.566666662693, loss=0.821544110775
   	mst_key=14: metric=0.608333349228, loss=1.00213348866
   	Validation set after iteration 1:
   	mst_key=11: metric=0.866666674614, loss=0.556911349297
   	mst_key=13: metric=0.699999988079, loss=0.577660501003
   	mst_key=6: metric=0.666666686535, loss=0.506598353386
   	mst_key=15: metric=0.600000023842, loss=0.782476782799
   	mst_key=14: metric=0.566666662693, loss=1.00861167908
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 2: 4.0810611248 sec
   DETAIL:  
   	Training set after iteration 2:
   	mst_key=11: metric=0.958333313465, loss=0.461322844028
   	mst_key=13: metric=0.658333361149, loss=0.488560408354
   	mst_key=6: metric=0.824999988079, loss=0.323733448982
   	mst_key=15: metric=0.875, loss=0.656310975552
   	mst_key=14: metric=0.658333361149, loss=0.641188919544
   	Validation set after iteration 2:
   	mst_key=11: metric=0.866666674614, loss=0.486970752478
   	mst_key=13: metric=0.699999988079, loss=0.468940109015
   	mst_key=6: metric=0.733333349228, loss=0.393083393574
   	mst_key=15: metric=0.800000011921, loss=0.679034292698
   	mst_key=14: metric=0.699999988079, loss=0.636842608452
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 3: 3.9402680397 sec
   DETAIL:  
   	Training set after iteration 3:
   	mst_key=11: metric=0.841666638851, loss=0.428030073643
   	mst_key=13: metric=0.933333337307, loss=0.341508597136
   	mst_key=6: metric=0.933333337307, loss=0.217096969485
   	mst_key=15: metric=0.941666662693, loss=0.388799399137
   	mst_key=14: metric=0.916666686535, loss=0.44380658865
   	Validation set after iteration 3:
   	mst_key=11: metric=0.733333349228, loss=0.495877951384
   	mst_key=13: metric=0.966666638851, loss=0.35888004303
   	mst_key=6: metric=0.933333337307, loss=0.196621760726
   	mst_key=15: metric=0.866666674614, loss=0.401055067778
   	mst_key=14: metric=0.933333337307, loss=0.44562998414
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 4: 4.27783608437 sec
   DETAIL:  
   	Training set after iteration 4:
   	mst_key=11: metric=0.899999976158, loss=0.348980903625
   	mst_key=13: metric=0.949999988079, loss=0.273842066526
   	mst_key=6: metric=0.824999988079, loss=0.307050615549
   	mst_key=15: metric=0.941666662693, loss=0.25471162796
   	mst_key=14: metric=0.966666638851, loss=0.266296118498
   	Validation set after iteration 4:
   	mst_key=11: metric=0.833333313465, loss=0.420216292143
   	mst_key=13: metric=0.866666674614, loss=0.317742973566
   	mst_key=6: metric=0.899999976158, loss=0.230505168438
   	mst_key=15: metric=0.899999976158, loss=0.281882911921
   	mst_key=14: metric=0.966666638851, loss=0.275415152311
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 5: 4.06583809853 sec
   DETAIL:  
   	Training set after iteration 5:
   	mst_key=11: metric=0.958333313465, loss=0.315752029419
   	mst_key=13: metric=0.708333313465, loss=0.4172565341
   	mst_key=6: metric=0.958333313465, loss=0.151342079043
   	mst_key=15: metric=0.975000023842, loss=0.179173886776
   	mst_key=14: metric=0.983333349228, loss=0.177039191127
   	Validation set after iteration 5:
   	mst_key=11: metric=0.866666674614, loss=0.342488378286
   	mst_key=13: metric=0.633333325386, loss=0.558800399303
   	mst_key=6: metric=0.966666638851, loss=0.153643578291
   	mst_key=15: metric=1.0, loss=0.174838155508
   	mst_key=14: metric=0.966666638851, loss=0.192256584764
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 6: 4.02019715309 sec
   DETAIL:  
   	Training set after iteration 6:
   	mst_key=11: metric=0.883333325386, loss=0.296552926302
   	mst_key=13: metric=0.899999976158, loss=0.248627394438
   	mst_key=6: metric=0.858333349228, loss=0.285588264465
   	mst_key=15: metric=0.949999988079, loss=0.167279005051
   	mst_key=14: metric=0.899999976158, loss=0.227244228125
   	Validation set after iteration 6:
   	mst_key=11: metric=0.833333313465, loss=0.401087760925
   	mst_key=13: metric=0.833333313465, loss=0.321891099215
   	mst_key=6: metric=0.733333349228, loss=0.421591818333
   	mst_key=15: metric=0.966666638851, loss=0.138948410749
   	mst_key=14: metric=0.800000011921, loss=0.34265729785
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 7: 4.27092909813 sec
   DETAIL:  
   	Training set after iteration 7:
   	mst_key=11: metric=0.891666650772, loss=0.248297154903
   	mst_key=13: metric=0.891666650772, loss=0.24527259171
   	mst_key=6: metric=0.949999988079, loss=0.12675113976
   	mst_key=15: metric=0.916666686535, loss=0.215490117669
   	mst_key=14: metric=0.958333313465, loss=0.127076357603
   	Validation set after iteration 7:
   	mst_key=11: metric=0.866666674614, loss=0.335224747658
   	mst_key=13: metric=0.766666650772, loss=0.349867284298
   	mst_key=6: metric=1.0, loss=0.120815463364
   	mst_key=15: metric=0.933333337307, loss=0.161028474569
   	mst_key=14: metric=0.899999976158, loss=0.196913182735
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 8: 4.16106319427 sec
   DETAIL:  
   	Training set after iteration 8:
   	mst_key=11: metric=0.958333313465, loss=0.21432980895
   	mst_key=13: metric=0.75, loss=0.440660595894
   	mst_key=6: metric=0.791666686535, loss=0.416923582554
   	mst_key=15: metric=0.933333337307, loss=0.161955311894
   	mst_key=14: metric=0.908333361149, loss=0.182602509856
   	Validation set after iteration 8:
   	mst_key=11: metric=0.866666674614, loss=0.264829039574
   	mst_key=13: metric=0.699999988079, loss=0.644149899483
   	mst_key=6: metric=0.899999976158, loss=0.281583458185
   	mst_key=15: metric=0.866666674614, loss=0.226034641266
   	mst_key=14: metric=0.866666674614, loss=0.33136048913
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 9: 4.84911680222 sec
   DETAIL:  
   	Training set after iteration 9:
   	mst_key=11: metric=0.975000023842, loss=0.221068292856
   	mst_key=13: metric=0.983333349228, loss=0.138952031732
   	mst_key=6: metric=0.816666662693, loss=0.403530687094
   	mst_key=15: metric=0.941666662693, loss=0.168528050184
   	mst_key=14: metric=0.958333313465, loss=0.14608065784
   	Validation set after iteration 9:
   	mst_key=11: metric=0.966666638851, loss=0.236671224236
   	mst_key=13: metric=0.966666638851, loss=0.161113128066
   	mst_key=6: metric=0.899999976158, loss=0.267582565546
   	mst_key=15: metric=0.966666638851, loss=0.120766565204
   	mst_key=14: metric=0.966666638851, loss=0.127617403865
   CONTEXT:  PL/Python function "madlib_keras_automl"
    madlib_keras_automl 
   ---------------------
    
   (1 row)
   
   Time: 93722.857 ms
   ```
   ![image](https://user-images.githubusercontent.com/10538173/92820545-ee707e00-f37e-11ea-927b-df37bdf50dce.png)
   
   OK
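The three diagonal banners printed in this run (R=9, eta=3) can be regenerated from the loop indices in `find_hyperband_config`. On diagonal i, each bracket s from s_max down to s_max-i sits at round s - s_max + i and keeps n * eta^-round of its initial n configs. Every bracket on that diagonal trains for the initial iteration count of the bracket that starts on it.

The sketch below is a standalone illustration, not MADlib code. `floor_log` and `diagonal_messages` are hypothetical names. It computes s_max and the initial n and r with integer arithmetic, on the assumption that this matches `math.floor(math.log(R, eta))` for these inputs.

```python
def floor_log(R, eta):
    """Largest integer s with eta**s <= R, i.e. floor(log_eta(R))."""
    s = 0
    while eta ** (s + 1) <= R:
        s += 1
    return s


def diagonal_messages(R, eta, skip_last=0):
    """Rebuild the '*** Diagonally evaluating ... ***' banner for each diagonal."""
    s_max = floor_log(R, eta)
    # Initial configs n and iterations r per bracket, as in find_hyperband_config.
    n0 = {s: ((s_max + 1) // (s + 1)) * eta ** s for s in range(s_max + 1)}
    r0 = {s: int(round(R / eta ** s)) for s in range(s_max + 1)}
    messages = []
    for i in range(s_max + 1 - skip_last):         # outer loop: diagonal i
        parts = []
        for s in range(s_max, s_max - i - 1, -1):  # inner loop: brackets on it
            rnd = s - s_max + i                    # round reached by bracket s
            n_i = int(round(n0[s] / eta ** rnd))
            parts.append("{0} configs under bracket={1} & round={2}".format(n_i, s, rnd))
        # all brackets on the diagonal train for the iterations of the newest bracket
        messages.append("*** Diagonally evaluating " + ", ".join(parts) +
                        " with {0} iterations ***".format(r0[s_max - i]))
    return messages


for message in diagonal_messages(9, 3):
    print(message)
```

With `skip_last=1`, the last diagonal (the one that starts bracket 0) is dropped. This matches the `range((s_max+1) - skip_last)` bound of the outer loop.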


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [madlib] khannaekta commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
khannaekta commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r485862374



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +157,406 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib

Review comment:
       We should explicitly check for params passed in as None and set them to the default values for `automl_method`, `automl_params` and `use_gpus`. This matters when we only want to pass in the `object_table` for a custom loss function but still want the default `automl_params` and `automl_method` to be used.
   Otherwise, it fails as follows: 
   ```
   SELECT madlib_keras_automl('iris_data_packed', 'automl_output', 'iris_model_arch', 'automl_mst_table',
       ARRAY[1,2],
       $${'loss': ['test_custom_fn1'], 'optimizer_params_list': [ {'optimizer': ['Adagrad', 'Adam'],
           'lr': [0.9, 0.95, 'log'], 'epsilon': [0.3, 0.5, 'log_near_one']}, {'optimizer': ['Adam', 'SGD'],
           'lr': [0.6, 0.65, 'log']} ], 'metrics':['accuracy'] }$$,
       $${'batch_size': [2, 4], 'epochs': [3]}$$,
       NULL, NULL, NULL, 'test_custom_function_table', NULL);
   ERROR:  AttributeError: 'NoneType' object has no attribute 'lower' (plpy_elog.c:121)
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "madlib_keras_automl", line 21, in <module>
       schedule_loader = madlib_keras_automl.KerasAutoML(**globals())
     PL/Python function "madlib_keras_automl", line 42, in wrapper
     PL/Python function "madlib_keras_automl", line 200, in __init__
     PL/Python function "madlib_keras_automl", line 257, in validate_and_define_inputs
   PL/Python function "madlib_keras_automl"
   ```
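The reviewer's suggestion could be implemented by normalizing NULL arguments before any of them is used. This is a minimal sketch that assumes it runs at the top of `KerasAutoML.__init__`. `resolve_automl_defaults` and `AUTOML_DEFAULTS` are hypothetical names, and the default values are copied from the `__init__` signature in the diff above.

```python
# Hypothetical defaults, copied from the KerasAutoML.__init__ signature.
AUTOML_DEFAULTS = {
    'automl_method': 'hyperband',
    'automl_params': 'R=6, eta=3, skip_last=0',
    'use_gpus': False,
}


def resolve_automl_defaults(automl_method, automl_params, use_gpus):
    """Hypothetical helper: map SQL NULL (Python None) to the documented defaults."""
    if automl_method is None:
        automl_method = AUTOML_DEFAULTS['automl_method']
    if automl_params is None:
        automl_params = AUTOML_DEFAULTS['automl_params']
    if use_gpus is None:
        use_gpus = AUTOML_DEFAULTS['use_gpus']
    return automl_method, automl_params, use_gpus


# The failing call above passed NULL for all three; after resolution,
# automl_method.lower() no longer raises AttributeError.
method, params, gpus = resolve_automl_defaults(None, None, None)
print(method.lower(), params, gpus)
```

Testing `is None` rather than truthiness keeps an explicit `FALSE` for `use_gpus` intact. `random_state` is deliberately left out because `None` is a valid value for it.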







[GitHub] [madlib] fmcquillan99 edited a comment on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
fmcquillan99 edited a comment on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-690782523


   iris dataset
   
   (2)
   schedule:
   ```
    s | i | n_i | r_i 
   ---+---+-----+-----
    3 | 0 |  27 |   1
    3 | 1 |   9 |   3
    3 | 2 |   3 |   9
    3 | 3 |   1 |  27
    2 | 0 |   9 |   3
    2 | 1 |   3 |   9
    2 | 2 |   1 |  27
    1 | 0 |   6 |   9
    1 | 1 |   2 |  27
    0 | 0 |   4 |  27
   ```
   produces:
   ![image](https://user-images.githubusercontent.com/10538173/92822881-81aab300-f381-11ea-92fb-6b0f132f95bc.png)
   
   e.g., for bracket 2
   ![image](https://user-images.githubusercontent.com/10538173/92823106-c46c8b00-f381-11ea-967b-af9342f6690a.png)
   
   OK
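The schedule above follows from the per-bracket formulas in `find_hyperband_config`. With s_max = floor(log_eta(R)) and B = (s_max+1)*R, bracket s starts with n = int(B/R/(s+1)) * eta^s configs at r = R * eta^-s iterations. Each later round divides the number of configs by eta and multiplies the iterations by eta.

The sketch below reproduces the table for R=27, eta=3. It is not the PR's `HyperbandSchedule` class, and `floor_log` and `hyperband_schedule` are hypothetical names. Integer arithmetic replaces `math.log` to sidestep floating-point rounding for exact powers of eta.

```python
def floor_log(R, eta):
    """Largest integer s with eta**s <= R, i.e. floor(log_eta(R)) without float error."""
    s = 0
    while eta ** (s + 1) <= R:
        s += 1
    return s


def hyperband_schedule(R, eta):
    """Return (s, i, n_i, r_i) rows, brackets from s_max down to 0."""
    s_max = floor_log(R, eta)
    rows = []
    for s in range(s_max, -1, -1):
        # n = int(B/R/(s+1)) * eta**s with B = (s_max+1)*R, as in find_hyperband_config
        n = ((s_max + 1) // (s + 1)) * eta ** s
        for i in range(s + 1):
            n_i = int(round(n / eta ** i))        # configs kept in round i
            r_i = int(round(R / eta ** (s - i)))  # iterations per config in round i
            rows.append((s, i, n_i, r_i))
    return rows


for row in hyperband_schedule(27, 3):
    print(row)
```

Calling `hyperband_schedule(9, 3)` gives the six-row R=9 schedule used in the other iris run.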
   





[GitHub] [madlib] Advitya17 commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
Advitya17 commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r479878180



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 and len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max+1))
+            # total number of resources/iterations (without reuse) per execution of Successive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: iterations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum([initial_vals[k][0] for k in initial_vals][self.skip_last:]), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,
+                                              AutoMLSchema.TEMP_MST_TABLE, num_iterations, self.use_gpus,
+                                              self.validation_table, mcf, self.warm_start, self.name, self.description)
+            self.update_model_output_table(model_training)
+            self.update_model_output_info_table(i, model_training, initial_vals)
+        self.end_training_time = self.get_current_timestamp()
+        self.update_model_selection_table()
+        self.generate_model_output_summary_table(model_training)
+        self.remove_temp_tables(model_training)
+
+    def get_current_timestamp(self):
+        """for start and end times for the chosen AutoML algorithm. Showcased in the output summary table"""
+        return datetime.fromtimestamp(time()).strftime('%Y-%m-%d %H:%M:%S')
+
+    def mst_key_ranges_dict(self, initial_vals):
+        """
+        Extracts the ranges of model configs (using mst_keys) belonging to / sampled as part of
+        executing a particular SHA bracket.
+        """
+        d = {}
+        for s_val in sorted(initial_vals.keys(), reverse=True): # going from s_max to 0
+            if s_val == self.s_max:
+                d[s_val] = (1, initial_vals[s_val][0])
+            else:
+                d[s_val] = (d[s_val+1][1]+1, d[s_val+1][1]+initial_vals[s_val][0])
+        return d
+
+    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
+        """
+        Drops and reconstructs a temp mst table for evaluation along particular diagonals of hyperband.
+        :param i: outer diagonal loop iteration.
+        :param ranges_dict: model config ranges to group by bracket number.
+        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
+        :return:
+        """
+        if i == 0:
+            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
+            lower_bound, upper_bound = ranges_dict[self.s_max]
+            plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
+                         "WHERE mst_key >= {lower_bound} AND mst_key <= {upper_bound}".format(self=self,
+                                                                                              AutoMLSchema=AutoMLSchema,
+                                                                                              lower_bound=lower_bound,
+                                                                                              upper_bound=upper_bound,))
+            return
+        # dropping and repopulating temp_mst_table
+        drop_tables([AutoMLSchema.TEMP_MST_TABLE])
+
+        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
+        create_query = """
+                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
+                            {mst_key} INTEGER,
+                            {model_id} INTEGER,
+                            {compile_params} VARCHAR,
+                            {fit_params} VARCHAR,
+                            unique ({model_id}, {compile_params}, {fit_params})
+                        );
+                       """.format(AutoMLSchema=AutoMLSchema,
+                                  mst_key=ModelSelectionSchema.MST_KEY,
+                                  model_id=ModelSelectionSchema.MODEL_ID,
+                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
+        with MinWarning('warning'):
+            plpy.execute(create_query)
+
+        query = ""
+        new_configs = True
+        for s_val in configs_prune_lookup:
+            lower_bound, upper_bound = ranges_dict[s_val]
+            if new_configs:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_selection_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                lower_bound=lower_bound, upper_bound=upper_bound)
+                new_configs = False
+            else:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_info_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
+                         "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                    lower_bound=lower_bound, upper_bound=upper_bound,
+                                                                    configs_prune_lookup_val=configs_prune_lookup[s_val])
+        plpy.execute(query)
+
+    def update_model_output_table(self, model_training):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output table.
+        :param model_training: Fit Multiple function call object.
+        """
+        # updates model weights for any previously trained configs
+        plpy.execute("UPDATE {self.model_output_table} a SET model_weights="

Review comment:
       Regarding one of your earlier messages saying that truncate may not be the issue: is this still applicable?
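The parsing and checks in `validate_and_define_inputs` can be exercised outside the database with a sketch like this one. `parse_automl_params` is a hypothetical name. Its naive split on commas and `=` stands in for MADlib's `extract_keyvalue_params`, whose exact parsing rules are assumed rather than shown here. Unknown keys are rejected while parsing, so at most three keys can ever be accepted.

```python
DEFAULTS = {'R': 6, 'eta': 3, 'skip_last': 0}


def parse_automl_params(automl_params):
    """Parse a string such as 'R=9, eta=3, skip_last=0' and validate it.

    Returns R, eta, skip_last plus the derived s_max and per-bracket budget B.
    """
    params = dict(DEFAULTS)
    for item in automl_params.split(','):
        if not item.strip():
            continue
        key, _, value = item.partition('=')
        key = key.strip()
        if key not in DEFAULTS:  # rejects anything other than R, eta, skip_last
            raise ValueError("DL: {0} is an invalid param".format(key))
        params[key] = int(value.strip())

    R, eta, skip_last = params['R'], params['eta'], params['skip_last']
    if eta <= 1:
        raise ValueError("DL: eta must be greater than 1")
    if R < eta:
        raise ValueError("DL: R should not be less than eta")

    # s_max = floor(log_eta(R)), computed with integers to avoid float rounding
    s_max = 0
    while eta ** (s_max + 1) <= R:
        s_max += 1
    if not 0 <= skip_last <= s_max:
        raise ValueError("DL: skip_last must be non-negative and less than {0}"
                         .format(s_max + 1))

    B = (s_max + 1) * R  # budget per bracket, as in the diff
    return {'R': R, 'eta': eta, 'skip_last': skip_last, 's_max': s_max, 'B': B}


print(parse_automl_params('R=9, eta=3, skip_last=0'))
```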







[GitHub] [madlib] fmcquillan99 edited a comment on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
fmcquillan99 edited a comment on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-690776040


   iris dataset
   
   (1)
   schedule:
   ```
    s | i | n_i | r_i 
   ---+---+-----+-----
    2 | 0 |   9 |   1
    2 | 1 |   3 |   3
    2 | 2 |   1 |   9
    1 | 0 |   3 |   3
    1 | 1 |   1 |   9
    0 | 0 |   3 |   9
   (6 rows)
   ```
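This R=9 schedule can also be used to check how the mst_keys for each diagonal are chosen. `mst_key_ranges_dict` assigns each bracket a contiguous block of the randomly sampled mst_keys, starting with s_max. `reconstruct_temp_mst_table` then keeps two things: the whole block of the bracket that starts on the current diagonal, and, for each ongoing bracket, the n_i configs with the lowest loss in its block.

The sketch below mirrors that logic in plain Python. `mst_key_ranges` and `select_configs` are hypothetical names, and the loss dictionary stands in for the `validation_loss_final` column of the `_info` table. The new bracket is identified explicitly as `s == s_max - i` instead of relying on dictionary iteration order.

```python
def mst_key_ranges(n0, s_max):
    """Assign each bracket (s_max first) a contiguous, 1-based block of mst_keys.

    n0[s] is the initial number of configs sampled for bracket s.
    """
    ranges = {}
    next_key = 1
    for s in range(s_max, -1, -1):
        ranges[s] = (next_key, next_key + n0[s] - 1)
        next_key += n0[s]
    return ranges


def select_configs(i, s_max, ranges, prune_lookup, losses):
    """Return the sorted mst_keys to train on diagonal i.

    The bracket starting on this diagonal (s == s_max - i) contributes all of its
    fresh configs; every ongoing bracket keeps its prune_lookup[s] lowest-loss
    configs from its own key range.
    """
    selected = []
    for s in range(s_max, s_max - i - 1, -1):
        low, high = ranges[s]
        if s == s_max - i:
            selected.extend(range(low, high + 1))
        else:
            in_range = sorted((k for k in losses if low <= k <= high),
                              key=lambda k: losses[k])
            selected.extend(in_range[:prune_lookup[s]])
    return sorted(selected)


# R=9, eta=3: brackets 2, 1, 0 start with 9, 3, 3 configs.
ranges = mst_key_ranges({2: 9, 1: 3, 0: 3}, s_max=2)

# Validation losses after diagonal 0 (copied from the log of this run).
losses = {1: 0.866640985012, 2: 0.534965932369, 3: 0.872259676456,
          4: 0.885949194431, 5: 1.10424804688, 6: 0.799027740955,
          7: 0.974297106266, 8: 1.03363573551, 9: 0.829889714718}
diag1 = select_configs(1, 2, ranges, {2: 3, 1: 3}, losses)

# Losses after diagonal 1 overwrite those of the retrained configs.
losses.update({2: 0.758127868176, 6: 0.303717941046, 9: 0.365277469158,
               10: 1.09673452377, 11: 0.615865588188, 12: 0.689917981625})
diag2 = select_configs(2, 2, ranges, {2: 1, 1: 1, 0: 3}, losses)
print(ranges, diag1, diag2)
```

Using the validation losses logged in this run, it selects mst_keys 2, 6, 9, 10, 11, 12 for the second diagonal and 6, 11, 13, 14, 15 for the third. These are the same keys that appear in the corresponding training output. The sketch assumes that configs which are not retrained keep their earlier loss in the `_info` table. Here that makes no difference, because their losses are all higher than the survivors'.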
   train:
   ```
   DROP TABLE IF EXISTS automl_output, automl_output_info, automl_output_summary, automl_mst_table, automl_mst_table_summary;
   
   SELECT madlib.madlib_keras_automl('iris_train_packed', 
                                     'automl_output', 
                                     'model_arch_library_iris', 
                                     'automl_mst_table',
                                     ARRAY[1], 
                                     $${'loss': ['categorical_crossentropy'], 'optimizer_params_list': [ {'optimizer': ['Adam'],'lr': [0.00999, 0.01001, 'log']}], 'metrics':['accuracy']}$$, 
                                     $${'batch_size': [8], 'epochs': [1]}$$,
                                     'hyperband', 
                                     'R=9, eta=3, skip_last=0',
                                     42,                     -- random state
                                     NULL,                   -- object table
                                     FALSE,                  -- use GPUs
                                     'iris_test_packed',     -- validation table
                                     1,                      -- metrics compute freq
                                     'a name is a name',     -- name
                                     'a descr is a descr');  -- descr
   ```
   produces:
   ```
   INFO:  *** Diagonally evaluating 9 configs under bracket=2 & round=0 with 1 iterations ***
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 1: 8.45693087578 sec
   DETAIL:  
   	Training set after iteration 1:
   	mst_key=2: metric=0.933333337307, loss=0.538550555706
   	mst_key=8: metric=0.316666662693, loss=1.03528738022
   	mst_key=3: metric=0.675000011921, loss=0.833593308926
   	mst_key=6: metric=0.658333361149, loss=0.825735092163
   	mst_key=1: metric=0.808333337307, loss=0.872115910053
   	mst_key=7: metric=0.625, loss=0.974291205406
   	mst_key=9: metric=0.658333361149, loss=0.830321311951
   	mst_key=4: metric=0.658333361149, loss=0.885693848133
   	mst_key=5: metric=0.341666668653, loss=1.09756851196
   	Validation set after iteration 1:
   	mst_key=2: metric=0.966666638851, loss=0.534965932369
   	mst_key=8: metric=0.366666674614, loss=1.03363573551
   	mst_key=3: metric=0.633333325386, loss=0.872259676456
   	mst_key=6: metric=0.699999988079, loss=0.799027740955
   	mst_key=1: metric=0.833333313465, loss=0.866640985012
   	mst_key=7: metric=0.5, loss=0.974297106266
   	mst_key=9: metric=0.699999988079, loss=0.829889714718
   	mst_key=4: metric=0.699999988079, loss=0.885949194431
   	mst_key=5: metric=0.300000011921, loss=1.10424804688
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  *** Diagonally evaluating 3 configs under bracket=2 & round=1, 3 configs under bracket=1 & round=0 with 3 iterations ***
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 1: 4.89221119881 sec
   DETAIL:  
   	Training set after iteration 1:
   	mst_key=9: metric=0.658333361149, loss=0.530654907227
   	mst_key=6: metric=0.774999976158, loss=0.531277298927
   	mst_key=11: metric=0.691666662693, loss=0.861804306507
   	mst_key=2: metric=0.716666638851, loss=0.426045030355
   	mst_key=12: metric=0.608333349228, loss=0.984588444233
   	mst_key=10: metric=0.324999988079, loss=1.17993462086
   	Validation set after iteration 1:
   	mst_key=9: metric=0.699999988079, loss=0.514933288097
   	mst_key=6: metric=0.800000011921, loss=0.519642591476
   	mst_key=11: metric=0.666666686535, loss=0.861222624779
   	mst_key=2: metric=0.733333349228, loss=0.385236173868
   	mst_key=12: metric=0.699999988079, loss=0.961848556995
   	mst_key=10: metric=0.366666674614, loss=1.16651546955
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 2: 5.0256319046 sec
   DETAIL:  
   	Training set after iteration 2:
   	mst_key=9: metric=0.966666638851, loss=0.432032138109
   	mst_key=6: metric=0.908333361149, loss=0.361528068781
   	mst_key=11: metric=0.858333349228, loss=0.686981141567
   	mst_key=2: metric=0.758333325386, loss=0.373835682869
   	mst_key=12: metric=0.600000023842, loss=0.89406478405
   	mst_key=10: metric=0.341666668653, loss=1.13211238384
   	Validation set after iteration 2:
   	mst_key=9: metric=0.933333337307, loss=0.442799389362
   	mst_key=6: metric=0.933333337307, loss=0.350269824266
   	mst_key=11: metric=0.833333313465, loss=0.694237530231
   	mst_key=2: metric=0.899999976158, loss=0.325422286987
   	mst_key=12: metric=0.699999988079, loss=0.868638038635
   	mst_key=10: metric=0.300000011921, loss=1.14590144157
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 3: 4.99592685699 sec
   DETAIL:  
   	Training set after iteration 3:
   	mst_key=9: metric=0.699999988079, loss=0.403143405914
   	mst_key=6: metric=0.949999988079, loss=0.292406588793
   	mst_key=11: metric=0.949999988079, loss=0.614805757999
   	mst_key=2: metric=0.658333361149, loss=0.662374198437
   	mst_key=12: metric=0.47499999404, loss=0.713567197323
   	mst_key=10: metric=0.40000000596, loss=1.08757543564
   	Validation set after iteration 3:
   	mst_key=9: metric=0.733333349228, loss=0.365277469158
   	mst_key=6: metric=0.899999976158, loss=0.303717941046
   	mst_key=11: metric=0.966666638851, loss=0.615865588188
   	mst_key=2: metric=0.633333325386, loss=0.758127868176
   	mst_key=12: metric=0.566666662693, loss=0.689917981625
   	mst_key=10: metric=0.300000011921, loss=1.09673452377
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  *** Diagonally evaluating 1 configs under bracket=2 & round=2, 1 configs under bracket=1 & round=1, 3 configs under bracket=0 & round=0 with 9 iterations ***
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 1: 4.01545906067 sec
   DETAIL:  
   	Training set after iteration 1:
   	mst_key=11: metric=0.949999988079, loss=0.535947144032
   	mst_key=13: metric=0.658333361149, loss=0.576903045177
   	mst_key=6: metric=0.725000023842, loss=0.413865655661
   	mst_key=15: metric=0.566666662693, loss=0.821544110775
   	mst_key=14: metric=0.608333349228, loss=1.00213348866
   	Validation set after iteration 1:
   	mst_key=11: metric=0.866666674614, loss=0.556911349297
   	mst_key=13: metric=0.699999988079, loss=0.577660501003
   	mst_key=6: metric=0.666666686535, loss=0.506598353386
   	mst_key=15: metric=0.600000023842, loss=0.782476782799
   	mst_key=14: metric=0.566666662693, loss=1.00861167908
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 2: 4.0810611248 sec
   DETAIL:  
   	Training set after iteration 2:
   	mst_key=11: metric=0.958333313465, loss=0.461322844028
   	mst_key=13: metric=0.658333361149, loss=0.488560408354
   	mst_key=6: metric=0.824999988079, loss=0.323733448982
   	mst_key=15: metric=0.875, loss=0.656310975552
   	mst_key=14: metric=0.658333361149, loss=0.641188919544
   	Validation set after iteration 2:
   	mst_key=11: metric=0.866666674614, loss=0.486970752478
   	mst_key=13: metric=0.699999988079, loss=0.468940109015
   	mst_key=6: metric=0.733333349228, loss=0.393083393574
   	mst_key=15: metric=0.800000011921, loss=0.679034292698
   	mst_key=14: metric=0.699999988079, loss=0.636842608452
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 3: 3.9402680397 sec
   DETAIL:  
   	Training set after iteration 3:
   	mst_key=11: metric=0.841666638851, loss=0.428030073643
   	mst_key=13: metric=0.933333337307, loss=0.341508597136
   	mst_key=6: metric=0.933333337307, loss=0.217096969485
   	mst_key=15: metric=0.941666662693, loss=0.388799399137
   	mst_key=14: metric=0.916666686535, loss=0.44380658865
   	Validation set after iteration 3:
   	mst_key=11: metric=0.733333349228, loss=0.495877951384
   	mst_key=13: metric=0.966666638851, loss=0.35888004303
   	mst_key=6: metric=0.933333337307, loss=0.196621760726
   	mst_key=15: metric=0.866666674614, loss=0.401055067778
   	mst_key=14: metric=0.933333337307, loss=0.44562998414
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 4: 4.27783608437 sec
   DETAIL:  
   	Training set after iteration 4:
   	mst_key=11: metric=0.899999976158, loss=0.348980903625
   	mst_key=13: metric=0.949999988079, loss=0.273842066526
   	mst_key=6: metric=0.824999988079, loss=0.307050615549
   	mst_key=15: metric=0.941666662693, loss=0.25471162796
   	mst_key=14: metric=0.966666638851, loss=0.266296118498
   	Validation set after iteration 4:
   	mst_key=11: metric=0.833333313465, loss=0.420216292143
   	mst_key=13: metric=0.866666674614, loss=0.317742973566
   	mst_key=6: metric=0.899999976158, loss=0.230505168438
   	mst_key=15: metric=0.899999976158, loss=0.281882911921
   	mst_key=14: metric=0.966666638851, loss=0.275415152311
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 5: 4.06583809853 sec
   DETAIL:  
   	Training set after iteration 5:
   	mst_key=11: metric=0.958333313465, loss=0.315752029419
   	mst_key=13: metric=0.708333313465, loss=0.4172565341
   	mst_key=6: metric=0.958333313465, loss=0.151342079043
   	mst_key=15: metric=0.975000023842, loss=0.179173886776
   	mst_key=14: metric=0.983333349228, loss=0.177039191127
   	Validation set after iteration 5:
   	mst_key=11: metric=0.866666674614, loss=0.342488378286
   	mst_key=13: metric=0.633333325386, loss=0.558800399303
   	mst_key=6: metric=0.966666638851, loss=0.153643578291
   	mst_key=15: metric=1.0, loss=0.174838155508
   	mst_key=14: metric=0.966666638851, loss=0.192256584764
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 6: 4.02019715309 sec
   DETAIL:  
   	Training set after iteration 6:
   	mst_key=11: metric=0.883333325386, loss=0.296552926302
   	mst_key=13: metric=0.899999976158, loss=0.248627394438
   	mst_key=6: metric=0.858333349228, loss=0.285588264465
   	mst_key=15: metric=0.949999988079, loss=0.167279005051
   	mst_key=14: metric=0.899999976158, loss=0.227244228125
   	Validation set after iteration 6:
   	mst_key=11: metric=0.833333313465, loss=0.401087760925
   	mst_key=13: metric=0.833333313465, loss=0.321891099215
   	mst_key=6: metric=0.733333349228, loss=0.421591818333
   	mst_key=15: metric=0.966666638851, loss=0.138948410749
   	mst_key=14: metric=0.800000011921, loss=0.34265729785
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 7: 4.27092909813 sec
   DETAIL:  
   	Training set after iteration 7:
   	mst_key=11: metric=0.891666650772, loss=0.248297154903
   	mst_key=13: metric=0.891666650772, loss=0.24527259171
   	mst_key=6: metric=0.949999988079, loss=0.12675113976
   	mst_key=15: metric=0.916666686535, loss=0.215490117669
   	mst_key=14: metric=0.958333313465, loss=0.127076357603
   	Validation set after iteration 7:
   	mst_key=11: metric=0.866666674614, loss=0.335224747658
   	mst_key=13: metric=0.766666650772, loss=0.349867284298
   	mst_key=6: metric=1.0, loss=0.120815463364
   	mst_key=15: metric=0.933333337307, loss=0.161028474569
   	mst_key=14: metric=0.899999976158, loss=0.196913182735
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 8: 4.16106319427 sec
   DETAIL:  
   	Training set after iteration 8:
   	mst_key=11: metric=0.958333313465, loss=0.21432980895
   	mst_key=13: metric=0.75, loss=0.440660595894
   	mst_key=6: metric=0.791666686535, loss=0.416923582554
   	mst_key=15: metric=0.933333337307, loss=0.161955311894
   	mst_key=14: metric=0.908333361149, loss=0.182602509856
   	Validation set after iteration 8:
   	mst_key=11: metric=0.866666674614, loss=0.264829039574
   	mst_key=13: metric=0.699999988079, loss=0.644149899483
   	mst_key=6: metric=0.899999976158, loss=0.281583458185
   	mst_key=15: metric=0.866666674614, loss=0.226034641266
   	mst_key=14: metric=0.866666674614, loss=0.33136048913
   CONTEXT:  PL/Python function "madlib_keras_automl"
   INFO:  
   	Time for training in iteration 9: 4.84911680222 sec
   DETAIL:  
   	Training set after iteration 9:
   	mst_key=11: metric=0.975000023842, loss=0.221068292856
   	mst_key=13: metric=0.983333349228, loss=0.138952031732
   	mst_key=6: metric=0.816666662693, loss=0.403530687094
   	mst_key=15: metric=0.941666662693, loss=0.168528050184
   	mst_key=14: metric=0.958333313465, loss=0.14608065784
   	Validation set after iteration 9:
   	mst_key=11: metric=0.966666638851, loss=0.236671224236
   	mst_key=13: metric=0.966666638851, loss=0.161113128066
   	mst_key=6: metric=0.899999976158, loss=0.267582565546
   	mst_key=15: metric=0.966666638851, loss=0.120766565204
   	mst_key=14: metric=0.966666638851, loss=0.127617403865
   CONTEXT:  PL/Python function "madlib_keras_automl"
    madlib_keras_automl 
   ---------------------
    
   (1 row)
   
   Time: 93722.857 ms
   ```
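The mst_keys carried forward in the log above are consistent with plain successive halving on the final validation loss. The diff later in this thread sets `AutoMLSchema.LOSS_METRIC` to `'validation_loss_final'` when a validation table is given. For example, after bracket 2, round 0, keys 2, 6 and 9 have the three lowest validation losses, and they are the ones retrained in the next diagonal. The sketch below shows that selection step; `successive_halving_survivors` is a hypothetical helper, not MADlib's code.

```python
from typing import Dict, List


def successive_halving_survivors(final_loss: Dict[int, float], n_keep: int) -> List[int]:
    """Hypothetical helper: mst_keys of the n_keep configs with the lowest final loss."""
    # sorted() is stable, so configs with tied losses keep their input order.
    ranked = sorted(final_loss, key=final_loss.get)
    return sorted(ranked[:n_keep])


if __name__ == "__main__":
    # Validation losses after bracket 2, round 0, copied from the log above.
    round0 = {2: 0.534965932369, 8: 1.03363573551, 3: 0.872259676456,
              6: 0.799027740955, 1: 0.866640985012, 7: 0.974297106266,
              9: 0.829889714718, 4: 0.885949194431, 5: 1.10424804688}
    # n_1 = 9 // 3 = 3 configs survive into round 1.
    print(successive_halving_survivors(round0, n_keep=3))  # [2, 6, 9]
```

The same rule explains the last diagonal. After 3 iterations, mst_key 6 has the lowest validation loss in bracket 2, round 1, and mst_key 11 has the lowest in bracket 1, round 0. Those two keys are trained again alongside the new bracket 0 configs.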
   ![image](https://user-images.githubusercontent.com/10538173/92820545-ee707e00-f37e-11ea-927b-df37bdf50dce.png)
   
   e.g., for bracket 2
   ![image](https://user-images.githubusercontent.com/10538173/92820822-3db6ae80-f37f-11ea-9b53-9dd7cd4d055b.png)
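The "Diagonally evaluating ..." lines in the log group rounds from different brackets that need the same number of iterations. For this schedule, round `i` of bracket `s` lands on diagonal `d = i + (s_max - s)`. Every round on a diagonal has the same `r_i = R * eta**(d - s_max)`, so one training call per diagonal suffices: 3 calls (`s_max + 1`) instead of 6 separate rounds.

The sketch below takes the schedule rows as input and groups them this way. `group_diagonally` is hypothetical, not MADlib's implementation.

```python
from collections import defaultdict
from typing import DefaultDict, List, Tuple

Row = Tuple[int, int, int, float]  # (bracket s, round i, configs n_i, resource r_i)


def group_diagonally(rows: List[Row]) -> List[List[Row]]:
    """Hypothetical sketch: bucket Hyperband rounds by diagonal d = i + (s_max - s)."""
    s_max = max(s for s, _, _, _ in rows)
    buckets: DefaultDict[int, List[Row]] = defaultdict(list)
    for row in rows:
        s, i, _, _ = row
        buckets[i + (s_max - s)].append(row)
    return [buckets[d] for d in sorted(buckets)]


if __name__ == "__main__":
    # The R=9, eta=3, skip_last=0 schedule shown earlier in this comment.
    schedule = [(2, 0, 9, 1.0), (2, 1, 3, 3.0), (2, 2, 1, 9.0),
                (1, 0, 3, 3.0), (1, 1, 1, 9.0), (0, 0, 3, 9.0)]
    for d, diag in enumerate(group_diagonally(schedule)):
        configs = sum(n_i for _, _, n_i, _ in diag)
        iters = {r_i for _, _, _, r_i in diag}
        print(d, configs, iters, [(s, i) for s, i, _, _ in diag])
```

The three diagonals hold 9, 6 and 5 configs with 1, 3 and 9 iterations respectively. This matches the config counts and iteration counts reported by the three "Diagonally evaluating" lines above.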
   
   
   
   OK


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [madlib] fmcquillan99 commented on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
fmcquillan99 commented on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-688930405


   @Advitya17 What is the syntax error here please?
   ```
   SELECT madlib.madlib_keras_automl('cifar10_train_packed', 
                              'automl_cifar10_output', 
                              'model_arch_library', 
                              'automl_cifar_10_mst_table',
                              ARRAY[1], 
                              $${'loss': ['categorical_crossentropy'], 
                                 'optimizer_params_list': [
                                     {'optimizer': ['Adam'],'lr': [0.001, 0.0001, 'log']}
                                     {'optimizer': ['Adam'],'lr': [0.001, 0.0001, 'log']}],
                                'metrics':['accuracy']}$$, 
                              $${'batch_size': [64, 128], 'epochs': [1]}$$,
                              'hyperband', 
                              'R=9, eta=3, skip_last=2');
   ```
   ```
   InternalError: (psycopg2.errors.InternalError_) SyntaxError: invalid syntax (<unknown>, line 1) (plpython.c:5038)
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "madlib_keras_automl", line 21, in <module>
       schedule_loader = madlib_keras_automl.KerasAutoML(**globals())
     PL/Python function "madlib_keras_automl", line 42, in wrapper
     PL/Python function "madlib_keras_automl", line 215, in __init__
     PL/Python function "madlib_keras_automl", line 308, in find_hyperband_config
     PL/Python function "madlib_keras_automl", line 42, in wrapper
     PL/Python function "madlib_keras_automl", line 283, in __init__
     PL/Python function "madlib_keras_automl", line 48, in literal_eval
     PL/Python function "madlib_keras_automl", line 36, in parse
   PL/Python function "madlib_keras_automl"
   
   [SQL: SELECT madlib.madlib_keras_automl('cifar10_train_packed', 
                              'automl_cifar10_output', 
                              'model_arch_library', 
                              'automl_cifar_10_mst_table',
                              ARRAY[1], 
                              $${'loss': ['categorical_crossentropy'], 
                                 'optimizer_params_list': [
                                     {'optimizer': ['Adam'],'lr': [0.001, 0.0001, 'log']}
                                     {'optimizer': ['Adam'],'lr': [0.001, 0.0001, 'log']}],
                                'metrics':['accuracy']}$$, 
                              $${'batch_size': [64, 128], 'epochs': [1]}$$,
                              'hyperband', 
                              'R=9, eta=3, skip_last=2');]
   (Background on this error at: http://sqlalche.me/e/2j85)
   ```





[GitHub] [madlib] khannaekta commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
khannaekta commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r485862374



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +157,406 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib

Review comment:
       We should explicitly check whether `automl_method`, `automl_params` and `use_gpus` are passed in as None and, if so, set them to their default values. This is a valid scenario when the user only wants to pass in the `object_table` for a custom loss function but use the default `automl_params` and `automl_method`.
   Otherwise it fails as follows:
   ```
   SELECT madlib_keras_automl('iris_data_packed', 'automl_output', 'iris_model_arch', 'automl_mst_table',
                              ARRAY[1,2],
                              $${'loss': ['test_custom_fn1'],
                                 'optimizer_params_list': [
                                     {'optimizer': ['Adagrad', 'Adam'], 'lr': [0.9, 0.95, 'log'], 'epsilon': [0.3, 0.5, 'log_near_one']},
                                     {'optimizer': ['Adam', 'SGD'], 'lr': [0.6, 0.65, 'log']} ],
                                 'metrics':['accuracy'] }$$,
                              $${'batch_size': [2, 4], 'epochs': [3]}$$,
                              NULL, NULL, NULL, 'test_custom_function_table', NULL);
   ERROR:  AttributeError: 'NoneType' object has no attribute 'lower' (plpy_elog.c:121)
   CONTEXT:  Traceback (most recent call last):
     PL/Python function "madlib_keras_automl", line 21, in <module>
       schedule_loader = madlib_keras_automl.KerasAutoML(**globals())
     PL/Python function "madlib_keras_automl", line 42, in wrapper
     PL/Python function "madlib_keras_automl", line 200, in __init__
     PL/Python function "madlib_keras_automl", line 257, in validate_and_define_inputs
   PL/Python function "madlib_keras_automl"
   ```
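The traceback shows why the keyword defaults in `__init__` don't help. The PL/Python wrapper calls `KerasAutoML(**globals())`, so a SQL `NULL` arrives as an explicit `None`, and `automl_method.lower()` then fails. Below is a minimal sketch of the suggested check, using the defaults from the diff's signature. `resolve_automl_defaults` is a hypothetical helper, not MADlib code.

```python
from typing import Optional, Tuple

# Defaults copied from the keyword arguments in the diff's __init__ signature.
DEFAULT_AUTOML_METHOD = 'hyperband'
DEFAULT_AUTOML_PARAMS = 'R=6, eta=3, skip_last=0'
DEFAULT_USE_GPUS = False


def resolve_automl_defaults(automl_method: Optional[str],
                            automl_params: Optional[str],
                            use_gpus: Optional[bool]) -> Tuple[str, str, bool]:
    """Replace an explicit None (SQL NULL) with the signature's default."""
    # Test against None rather than using `or`, so only SQL NULL triggers a
    # default; other falsy values such as '' still reach input validation.
    if automl_method is None:
        automl_method = DEFAULT_AUTOML_METHOD
    if automl_params is None:
        automl_params = DEFAULT_AUTOML_PARAMS
    if use_gpus is None:
        use_gpus = DEFAULT_USE_GPUS
    return automl_method, automl_params, use_gpus


if __name__ == "__main__":
    method, params, gpus = resolve_automl_defaults(None, None, None)
    # The prefix check from the diff no longer calls .lower() on None.
    print('hyperband'.startswith(method.lower()), params, gpus)
```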







[GitHub] [madlib] Advitya17 commented on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
Advitya17 commented on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-688997702


   > @Advitya17 What is the syntax error here please?
   
   - A comma is missing between the two dictionaries inside `optimizer_params_list`.
   - Every param in `optimizer_params_list` other than `optimizer` must use the format `[lower_bound, upper_bound, distribution_type]`, so you should specify `[0.0001, 0.001, 'log']` instead of `[0.001, 0.0001, 'log']`.
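The traceback ends in `literal_eval` and `parse`, which suggests the compile params grid is parsed with Python's `ast.literal_eval`. That would explain why a missing comma surfaces as `SyntaxError: invalid syntax`. The sketch below reproduces both problems, assuming the range format described above. `check_compile_params_grid` is hypothetical and is not the validator MADlib uses.

```python
import ast


def check_compile_params_grid(grid: str) -> dict:
    """Hypothetical check of a compile params grid string."""
    # A missing comma between the optimizer dicts raises SyntaxError here.
    parsed = ast.literal_eval(grid)
    for opt in parsed['optimizer_params_list']:
        for name, spec in opt.items():
            if name == 'optimizer':
                continue  # optimizer names are a plain list of choices
            # Every other entry must be [lower_bound, upper_bound, distribution_type].
            if (not isinstance(spec, (list, tuple)) or len(spec) != 3
                    or not isinstance(spec[2], str)):
                raise ValueError(f"{name}: expected [lower_bound, upper_bound, "
                                 f"distribution_type], got {spec}")
            lower, upper, _ = spec
            if not lower < upper:
                raise ValueError(f"{name}: lower bound {lower} must be smaller "
                                 f"than upper bound {upper}")
    return parsed


if __name__ == "__main__":
    fixed = ("{'loss': ['categorical_crossentropy'], "
             "'optimizer_params_list': ["
             "{'optimizer': ['Adam'], 'lr': [0.0001, 0.001, 'log']}, "
             "{'optimizer': ['Adam'], 'lr': [0.0001, 0.001, 'log']}], "
             "'metrics': ['accuracy']}")
    print(check_compile_params_grid(fixed)['optimizer_params_list'][0]['lr'])
```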





[GitHub] [madlib] Advitya17 commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
Advitya17 commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r479873400



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
##########
@@ -21,6 +21,347 @@
 
 m4_include(`SQLCommon.m4')
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+
+--------------------------- MADLIB KERAS AUTOML HYPERBAND TEST CASES ---------------------------
+
+DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
+SELECT training_preprocessor_dl('iris_train',         -- Source table
+                                       'iris_train_packed',  -- Output table

Review comment:
       Do you mean `iris_data_packed` or `iris_train_packed` for the source (i.e., training) table? I am currently using `iris_train_packed`, and I need these function calls; otherwise this table is not defined when I try to run Hyperband.







[GitHub] [madlib] Advitya17 edited a comment on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
Advitya17 edited a comment on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-688997702


   > @Advitya17 What is the syntax error here please?
   
   - A comma is missing between the two dictionaries inside `optimizer_params_list`.
   - Every param in `optimizer_params_list` other than `optimizer` must use the format `[lower_bound, upper_bound, distribution_type]`, so you should specify `[0.0001, 0.001, 'log']` instead of `[0.001, 0.0001, 'log']`.
   
   Thanks for the question, though. When I tried it with the corrected syntax, I got the assertion error about the distribution format (which expects `[0.0001, 0.001, 'log']` here) and saw that the error message I raise isn't very intuitive, so I can reword it a bit.





[GitHub] [madlib] khannaekta merged pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
khannaekta merged pull request #513:
URL: https://github.com/apache/madlib/pull/513









[GitHub] [madlib] khannaekta commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
khannaekta commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r480247892



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_automl.py_in
##########
@@ -138,3 +156,389 @@ class HyperbandSchedule():
                                       r_i_col=AutoMLSchema.RESOURCES,
                                       **locals())
             plpy.execute(insert_query)
+
+@MinWarning("warning")
+class KerasAutoML():
+    """The core AutoML function for running AutoML algorithms such as Hyperband.
+    This function executes the hyperband rounds 'diagonally' to evaluate multiple configurations together
+    and leverage the compute power of MPP databases such as Greenplum.
+    """
+    def __init__(self, schema_madlib, source_table, model_output_table, model_arch_table, model_selection_table,
+                 model_id_list, compile_params_grid, fit_params_grid, automl_method='hyperband',
+                 automl_params='R=6, eta=3, skip_last=0', random_state=None, object_table=None,
+                 use_gpus=False, validation_table=None, metrics_compute_frequency=None,
+                 name=None, description=None, **kwargs):
+        self.schema_madlib = schema_madlib
+        self.source_table = source_table
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(self.model_output_table, '_info')
+            self.model_summary_table = add_postfix(self.model_output_table, '_summary')
+        self.model_arch_table = model_arch_table
+        self.model_selection_table = model_selection_table
+        self.model_selection_summary_table = add_postfix(
+            model_selection_table, "_summary")
+        self.model_id_list = sorted(list(set(model_id_list)))
+        self.compile_params_grid = compile_params_grid
+        self.fit_params_grid = fit_params_grid
+
+        MstLoaderInputValidator(
+            model_arch_table=self.model_arch_table,
+            model_selection_table=self.model_selection_table,
+            model_selection_summary_table=self.model_selection_summary_table,
+            model_id_list=self.model_id_list,
+            compile_params_list=compile_params_grid,
+            fit_params_list=fit_params_grid,
+            object_table=object_table,
+            module_name='madlib_keras_automl'
+        )
+
+        self.automl_method = automl_method
+        self.automl_params = automl_params
+        self.random_state = random_state
+        self.validate_and_define_inputs()
+
+        self.object_table = object_table
+        self.use_gpus = use_gpus
+        self.validation_table = validation_table
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
+
+        if self.validation_table:
+            AutoMLSchema.LOSS_METRIC = 'validation_loss_final'
+
+        self.create_model_output_table()
+        self.create_model_output_info_table()
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            self.find_hyperband_config()
+
+    def create_model_output_table(self):
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                     {ModelArchSchema.MODEL_WEIGHTS} BYTEA,
+                                     {ModelArchSchema.MODEL_ARCH} JSON)
+                                    """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                               ModelArchSchema=ModelArchSchema)
+        with MinWarning('warning'):
+            plpy.execute(output_table_create_query)
+
+    def create_model_output_info_table(self):
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({ModelSelectionSchema.MST_KEY} INTEGER PRIMARY KEY,
+                                   {ModelArchSchema.MODEL_ID} INTEGER,
+                                   {ModelSelectionSchema.COMPILE_PARAMS} TEXT,
+                                   {ModelSelectionSchema.FIT_PARAMS} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[],
+                                   {AutoMLSchema.METRICS_ITERS} INTEGER[])
+                                       """.format(self=self, ModelSelectionSchema=ModelSelectionSchema,
+                                                  ModelArchSchema=ModelArchSchema, AutoMLSchema=AutoMLSchema)
+        with MinWarning('warning'):
+            plpy.execute(info_table_create_query)
+
+    def validate_and_define_inputs(self):
+
+        if AutoMLSchema.HYPERBAND.startswith(self.automl_method.lower()):
+            automl_params_dict = extract_keyvalue_params(self.automl_params,
+                                                         default_values={'R': 6, 'eta': 3, 'skip_last': 0},
+                                                         lower_case_names=False)
+            # casting dict values to int
+            for i in automl_params_dict:
+                automl_params_dict[i] = int(automl_params_dict[i])
+            _assert(len(automl_params_dict) >= 1 and len(automl_params_dict) <= 3,
+                    "DL: Only R, eta, and skip_last may be specified")
+            for i in automl_params_dict:
+                if i == AutoMLSchema.R:
+                    self.R = automl_params_dict[AutoMLSchema.R]
+                elif i == AutoMLSchema.ETA:
+                    self.eta = automl_params_dict[AutoMLSchema.ETA]
+                elif i == AutoMLSchema.SKIP_LAST:
+                    self.skip_last = automl_params_dict[AutoMLSchema.SKIP_LAST]
+                else:
+                    plpy.error("DL: {0} is an invalid param".format(i))
+            _assert(self.eta > 1, "DL: eta must be greater than 1")
+            _assert(self.R >= self.eta, "DL: R should not be less than eta")
+            self.s_max = int(math.floor(math.log(self.R, self.eta)))
+            _assert(self.skip_last >= 0 and self.skip_last < self.s_max+1, "DL: skip_last must be " +
+                    "non-negative and less than {0}".format(self.s_max+1))
+            # total number of resources/iterations (without reuse) per execution of Successive Halving (n,r)
+            self.B = (self.s_max + 1) * self.R
+        else:
+            plpy.error("DL: Only hyperband is currently supported as the automl method")
+
+    def _is_valid_metrics_compute_frequency(self, num_iterations):
+        """
+        Utility function (same as that in the Fit Multiple function) to check validity of mcf value for computing
+        metrics during an AutoML algorithm run.
+        :param num_iterations: iterations/resources to allocate for training.
+        :return: boolean on validity of the mcf value.
+        """
+        return self.metrics_compute_frequency is None or \
+               (self.metrics_compute_frequency >= 1 and \
+                self.metrics_compute_frequency <= num_iterations)
+
+    def find_hyperband_config(self):
+        """
+        Runs the diagonal hyperband algorithm.
+        """
+        initial_vals = {}
+
+        # get hyper parameter configs for each s
+        for s in reversed(range(self.s_max+1)):
+            n = int(math.ceil(int(self.B/self.R/(s+1))*math.pow(self.eta, s))) # initial number of configurations
+            r = self.R * math.pow(self.eta, -s) # initial number of iterations to run configurations for
+            initial_vals[s] = (n, int(round(r)))
+        self.start_training_time = self.get_current_timestamp()
+        random_search = MstSearch(self.model_arch_table, self.model_selection_table, self.model_id_list,
+                                  self.compile_params_grid, self.fit_params_grid, 'random',
+                                  sum(initial_vals[s][0] for s in range(self.skip_last, self.s_max+1)), self.random_state,
+                                  self.object_table)
+        random_search.load() # for populating mst tables
+
+        # for creating the summary table for usage in fit multiple
+        plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_SUMMARY_TABLE} AS " \
+                     "SELECT * FROM {random_search.model_selection_summary_table}".format(
+            AutoMLSchema=AutoMLSchema, random_search=random_search))
+        ranges_dict = self.mst_key_ranges_dict(initial_vals)
+
+        # outer loop on diagonal
+        for i in range((self.s_max+1) - int(self.skip_last)):
+            # inner loop on s desc
+            temp_lst = []
+            configs_prune_lookup = {}
+            for s in range(self.s_max, self.s_max-i-1, -1):
+                n = initial_vals[s][0]
+                n_i = n * math.pow(self.eta, -i+self.s_max-s)
+                configs_prune_lookup[s] = int(round(n_i))
+                temp_lst.append("{0} configs under bracket={1} & round={2}".format(int(n_i), s, s-self.s_max+i))
+            plpy.info('*** Diagonally evaluating ' + ', '.join(temp_lst) + ' with {0} iterations ***'.format(
+                int(initial_vals[self.s_max-i][1])))
+
+            self.reconstruct_temp_mst_table(i, ranges_dict, configs_prune_lookup)
+            self.warm_start = int(i != 0)
+            num_iterations = int(initial_vals[self.s_max-i][1])
+            mcf = self.metrics_compute_frequency if self._is_valid_metrics_compute_frequency(num_iterations) else None
+            model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLSchema.TEMP_OUTPUT_TABLE,
+                                              AutoMLSchema.TEMP_MST_TABLE, num_iterations, self.use_gpus,
+                                              self.validation_table, mcf, self.warm_start, self.name, self.description)
+            self.update_model_output_table(model_training)
+            self.update_model_output_info_table(i, model_training, initial_vals)
+        self.end_training_time = self.get_current_timestamp()
+        self.update_model_selection_table()
+        self.generate_model_output_summary_table(model_training)
+        self.remove_temp_tables(model_training)
+
+    def get_current_timestamp(self):
+        """for start and end times for the chosen AutoML algorithm. Showcased in the output summary table"""
+        return datetime.fromtimestamp(time()).strftime('%Y-%m-%d %H:%M:%S')
+
+    def mst_key_ranges_dict(self, initial_vals):
+        """
+        Extracts the ranges of model configs (using mst_keys) belonging to / sampled as part of
+        executing a particular SHA bracket.
+        """
+        d = {}
+        for s_val in sorted(initial_vals.keys(), reverse=True): # going from s_max to 0
+            if s_val == self.s_max:
+                d[s_val] = (1, initial_vals[s_val][0])
+            else:
+                d[s_val] = (d[s_val+1][1]+1, d[s_val+1][1]+initial_vals[s_val][0])
+        return d
+
+    def reconstruct_temp_mst_table(self, i, ranges_dict, configs_prune_lookup):
+        """
+        Drops and Reconstructs a temp mst table for evaluation along particular diagonals of hyperband.
+        :param i: outer diagonal loop iteration.
+        :param ranges_dict: model config ranges to group by bracket number.
+        :param configs_prune_lookup: Lookup dictionary for configs to evaluate for a diagonal.
+        :return:
+        """
+        if i == 0:
+            _assert_equal(len(configs_prune_lookup), 1, "invalid args")
+            lower_bound, upper_bound = ranges_dict[self.s_max]
+            plpy.execute("CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} AS SELECT * FROM {self.model_selection_table} "
+                         "WHERE mst_key >= {lower_bound} AND mst_key <= {upper_bound}".format(self=self,
+                                                                                              AutoMLSchema=AutoMLSchema,
+                                                                                              lower_bound=lower_bound,
+                                                                                              upper_bound=upper_bound,))
+            return
+        # dropping and repopulating temp_mst_table
+        drop_tables([AutoMLSchema.TEMP_MST_TABLE])
+
+        # {mst_key} changed from SERIAL to INTEGER for safe insertions and preservation of mst_key values
+        create_query = """
+                        CREATE TABLE {AutoMLSchema.TEMP_MST_TABLE} (
+                            {mst_key} INTEGER,
+                            {model_id} INTEGER,
+                            {compile_params} VARCHAR,
+                            {fit_params} VARCHAR,
+                            unique ({model_id}, {compile_params}, {fit_params})
+                        );
+                       """.format(AutoMLSchema=AutoMLSchema,
+                                  mst_key=ModelSelectionSchema.MST_KEY,
+                                  model_id=ModelSelectionSchema.MODEL_ID,
+                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
+        with MinWarning('warning'):
+            plpy.execute(create_query)
+
+        query = ""
+        new_configs = True
+        for s_val in sorted(configs_prune_lookup):  # ascending s: newest bracket (s_max-i) comes first
+            lower_bound, upper_bound = ranges_dict[s_val]
+            if new_configs:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_selection_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                lower_bound=lower_bound, upper_bound=upper_bound)
+                new_configs = False
+            else:
+                query += "INSERT INTO {AutoMLSchema.TEMP_MST_TABLE} SELECT mst_key, model_id, compile_params, fit_params " \
+                         "FROM {self.model_info_table} WHERE mst_key >= {lower_bound} " \
+                         "AND mst_key <= {upper_bound} ORDER BY {AutoMLSchema.LOSS_METRIC} " \
+                         "LIMIT {configs_prune_lookup_val};".format(self=self, AutoMLSchema=AutoMLSchema,
+                                                                    lower_bound=lower_bound, upper_bound=upper_bound,
+                                                                    configs_prune_lookup_val=configs_prune_lookup[s_val])
+        plpy.execute(query)
+
+    def update_model_output_table(self, model_training):
+        """
+        Updates gathered information of a hyperband diagonal run to the overall model output table.
+        :param model_training: Fit Multiple function call object.
+        """
+        # updates model weights for any previously trained configs
+        plpy.execute("UPDATE {self.model_output_table} a SET model_weights="

Review comment:
       Yes, this is still applicable
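Because the diff calls random search only once for all brackets, each bracket owns a contiguous block of mst_keys (computed by `mst_key_ranges_dict`), and `reconstruct_temp_mst_table` keeps only the best configs of each older bracket with `ORDER BY <loss> LIMIT k`. The plain-Python sketch below restates that bookkeeping outside the database under stated assumptions: `mst_key_ranges` and `survivors` are hypothetical names, and an in-memory dict of losses stands in for the model info table.

```python
def mst_key_ranges(initial_counts, s_max):
    """Map each bracket s to the inclusive (lower, upper) mst_key range it owns.

    Mirrors mst_key_ranges_dict: keys are numbered from 1, starting with
    bracket s_max and continuing downward to bracket 0.
    """
    ranges = {}
    for s in range(s_max, -1, -1):
        if s == s_max:
            ranges[s] = (1, initial_counts[s])
        else:
            prev_upper = ranges[s + 1][1]
            ranges[s] = (prev_upper + 1, prev_upper + initial_counts[s])
    return ranges


def survivors(losses, key_range, k):
    """Return the k lowest-loss mst_keys inside key_range.

    Plays the role of "ORDER BY <loss> LIMIT k"; ties are broken by mst_key
    here, whereas SQL leaves the order of tied rows unspecified.
    """
    lower, upper = key_range
    in_range = sorted((loss, key) for key, loss in losses.items()
                      if lower <= key <= upper)
    return [key for _, key in in_range[:k]]


# Bracket sizes n per s for R=27, eta=3 (the iris schedule).
ranges = mst_key_ranges({3: 27, 2: 9, 1: 6, 0: 4}, s_max=3)
# Hypothetical validation losses; key 30 belongs to bracket 2, not bracket 3.
losses = {1: 0.50, 2: 0.20, 3: 0.90, 30: 0.10}
best = survivors(losses, ranges[3], k=2)
print(ranges, best)
```

Keeping each bracket in one contiguous key range is what lets a single random-search call feed every bracket: selecting a bracket's configs reduces to a `mst_key BETWEEN lower AND upper` filter.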




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [madlib] fmcquillan99 commented on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
fmcquillan99 commented on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-690782523


   iris dataset
   
   (1)
   schedule:
   ```
    s | i | n_i | r_i 
   ---+---+-----+-----
    3 | 0 |  27 |   1
    3 | 1 |   9 |   3
    3 | 2 |   3 |   9
    3 | 3 |   1 |  27
    2 | 0 |   9 |   3
    2 | 1 |   3 |   9
    2 | 2 |   1 |  27
    1 | 0 |   6 |   9
    1 | 1 |   2 |  27
    0 | 0 |   4 |  27
   ```
   produces:
   ![image](https://user-images.githubusercontent.com/10538173/92822881-81aab300-f381-11ea-92fb-6b0f132f95bc.png)
   
   e.g., for bracket 2
   ![image](https://user-images.githubusercontent.com/10538173/92823106-c46c8b00-f381-11ea-967b-af9342f6690a.png)
   
   OK
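The iris schedule above follows from the bracket formulas in `find_hyperband_config`: B = (s_max + 1) * R, n = int(B/R/(s+1)) * eta**s, and each round divides the surviving configs by eta while multiplying the iterations by eta. The sketch below regenerates the table; `compute_s_max` and `hyperband_schedule` are hypothetical helpers, not MADlib functions. One deliberate difference from the diff: s_max is computed with integer arithmetic rather than `math.floor(math.log(R, eta))`, because the floating-point logarithm can land just below an exact integer (in CPython, `math.log(1000, 10)` is 2.9999999999999996), which would silently drop a bracket.

```python
def compute_s_max(R, eta):
    """Largest s with eta**s <= R, i.e. floor(log_eta(R)) in exact integer arithmetic."""
    s = 0
    while eta ** (s + 1) <= R:
        s += 1
    return s


def hyperband_schedule(R, eta):
    """Return (s, i, n_i, r_i) rows for every bracket s and round i (Python 3)."""
    s_max = compute_s_max(R, eta)
    rows = []
    for s in range(s_max, -1, -1):
        # Equal to int(B/R/(s+1)) * eta**s because B/R == s_max + 1 exactly.
        n = ((s_max + 1) // (s + 1)) * eta ** s
        for i in range(s + 1):
            n_i = int(round(n / eta ** i))        # configs kept in round i
            r_i = int(round(R / eta ** (s - i)))  # iterations per config in round i
            rows.append((s, i, n_i, r_i))
    return rows


for row in hyperband_schedule(27, 3):
    print(row)
```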
   





[GitHub] [madlib] khannaekta merged pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
khannaekta merged pull request #513:
URL: https://github.com/apache/madlib/pull/513









[GitHub] [madlib] fmcquillan99 commented on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
fmcquillan99 commented on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-690789241


   cifar-10 dataset
   
   (3)
   schedule:
   ```
    s | i | n_i | r_i 
   ---+---+-----+-----
    2 | 0 |   9 |   1
    2 | 1 |   3 |   3
    2 | 2 |   1 |   9
    1 | 0 |   3 |   3
    1 | 1 |   1 |   9
    0 | 0 |   3 |   9
   (6 rows)
   ```
   train:
   ```
   DROP TABLE IF EXISTS automl_cifar10_output, automl_cifar10_output_info, automl_cifar10_output_summary, automl_cifar_10_mst_table, automl_cifar_10_mst_table_summary;
   
   SELECT madlib.madlib_keras_automl('cifar10_train_packed', 
                              'automl_cifar10_output', 
                              'model_arch_library', 
                              'automl_cifar_10_mst_table',
                              ARRAY[1], 
                              $${'loss': ['categorical_crossentropy'], 
                                 'optimizer_params_list': [
                                     {'optimizer': ['Adam'],'lr': [0.0001, 0.001, 'log']},
                                     {'optimizer': ['RMSprop'],'lr': [0.0001, 0.001, 'log'],'decay': [1e-6]}],
                                'metrics':['accuracy']}$$, 
                              $${'batch_size': [64, 128], 'epochs': [1]}$$,
                              'hyperband', 
                              'R=9, eta=3, skip_last=0',                                                              
                               42,                    -- random state
                               NULL,                  -- object table
                               FALSE,                 -- use GPUs
                               'cifar10_test_packed', -- validation table
                               1,                     -- metrics compute freq
                               NULL,                  -- name
                               NULL);                 -- descr
   ```
   produces:
   ![image](https://user-images.githubusercontent.com/10538173/92825871-e7e50500-f384-11ea-8214-628c71355c83.png)
   
   for s=2:
   ![image](https://user-images.githubusercontent.com/10538173/92826038-18c53a00-f385-11ea-8e11-d13a6e6bccdd.png)
   
   OK
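In the cifar-10 schedule, rows that share the same r_i (for example bracket 2 round 1 and bracket 1 round 0, both at 3 iterations) lie on the same "diagonal", and the outer loop of `find_hyperband_config` trains all of them in one `FitMultipleModel` call. The sketch below groups the schedule that way; `diagonal_plan` is a hypothetical helper that mirrors the loop bounds and the `n_i` formula in the diff, and it shows that the number of training calls is s_max + 1 - skip_last.

```python
def diagonal_plan(R, eta, skip_last=0):
    """Group Hyperband (bracket, round, configs) triples by diagonal.

    Each entry is (num_iterations, [(s, round, n_configs), ...]); all triples
    on one diagonal share num_iterations and would be trained in one call.
    """
    s_max = 0
    while eta ** (s_max + 1) <= R:  # floor(log_eta(R)) without floats
        s_max += 1
    # (initial configs n, initial iterations r) per bracket s
    initial = {s: (((s_max + 1) // (s + 1)) * eta ** s, int(round(R / eta ** s)))
               for s in range(s_max + 1)}
    plan = []
    for i in range(s_max + 1 - skip_last):          # outer loop on diagonals
        num_iterations = initial[s_max - i][1]
        entries = []
        for s in range(s_max, s_max - i - 1, -1):   # brackets on this diagonal
            rnd = s - s_max + i
            entries.append((s, rnd, int(round(initial[s][0] / eta ** rnd))))
        plan.append((num_iterations, entries))
    return plan


for num_iterations, entries in diagonal_plan(9, 3):
    print(num_iterations, entries)
```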





[GitHub] [madlib] khannaekta commented on a change in pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
khannaekta commented on a change in pull request #513:
URL: https://github.com/apache/madlib/pull/513#discussion_r480247521



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
##########
@@ -21,6 +21,347 @@
 
 m4_include(`SQLCommon.m4')
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+
+--------------------------- MADLIB KERAS AUTOML HYPERBAND TEST CASES ---------------------------
+
+DROP TABLE IF EXISTS iris_train_packed, iris_train_packed_summary;
+SELECT training_preprocessor_dl('iris_train',         -- Source table
+                                       'iris_train_packed',  -- Output table

Review comment:
       If you look at the file `madlib_keras_iris.setup.sql_in`, there is already a table `iris_data_packed` which is equivalent to the `iris_train_packed` that you are creating here by calling `training_preprocessor_dl()`. Since `madlib_keras_iris.setup.sql_in` (L24:27) is executed as part of this test, you don't need to rerun `training_preprocessor_dl` and `validation_preprocessor_dl`; you can directly use the already preprocessed tables `iris_data_packed` and `iris_data_val`.





[GitHub] [madlib] fmcquillan99 commented on pull request #513: DL: [AutoML] Add support for 'diagonal' Hyperband optimized for MPP

Posted by GitBox <gi...@apache.org>.
fmcquillan99 commented on pull request #513:
URL: https://github.com/apache/madlib/pull/513#issuecomment-690789601


   I tested a bunch of error conditions and corner cases and it seems fine to me.  Great job on this PR!
   
   LGTM
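Error conditions of this kind come from `validate_and_define_inputs` in the diff. The sketch below restates those checks as a standalone function, assuming Python 3; `parse_hyperband_params` and `raises_value_error` are hypothetical, the simple comma/equals parser stands in for MADlib's internal `extract_keyvalue_params`, and unknown keys are rejected directly instead of through a separate count check.

```python
DEFAULTS = {'R': 6, 'eta': 3, 'skip_last': 0}  # defaults used in validate_and_define_inputs


def parse_hyperband_params(param_str):
    """Parse and validate a string such as 'R=9, eta=3, skip_last=0'.

    Returns (R, eta, skip_last, s_max) or raises ValueError.
    """
    params = dict(DEFAULTS)
    for item in (param_str or '').split(','):
        if not item.strip():
            continue  # tolerate empty input and trailing commas
        key, sep, value = item.partition('=')
        key = key.strip()
        if not sep or key not in params:
            raise ValueError("DL: {0} is an invalid param".format(key))
        params[key] = int(value.strip())
    R, eta, skip_last = params['R'], params['eta'], params['skip_last']
    # Check eta first: with eta <= 1 the s_max loop below would never end.
    if eta <= 1:
        raise ValueError("DL: eta must be greater than 1")
    if R < eta:
        raise ValueError("DL: R should not be less than eta")
    s_max = 0
    while eta ** (s_max + 1) <= R:
        s_max += 1
    if not 0 <= skip_last <= s_max:
        raise ValueError("DL: skip_last must be non-negative and less than {0}".format(s_max + 1))
    return R, eta, skip_last, s_max


def raises_value_error(param_str):
    """Return True if parse_hyperband_params rejects param_str."""
    try:
        parse_hyperband_params(param_str)
    except ValueError:
        return True
    return False


print(parse_hyperband_params('R=9, eta=3, skip_last=0'))
print(raises_value_error('R=9, skip_last=3'))
```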

