You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2020/12/04 19:17:19 UTC

[GitHub] [madlib] kaknikhil commented on a change in pull request #525: DL: Model Hopper Refactor

kaknikhil commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r530005710



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -196,72 +195,110 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
+        self.model_input_tbl = unique_string(desp='model_input')
+        self.schedule_tbl = unique_string(desp='schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        DEBUG.plpy.info("init_dist_keys = {0}".format(self.dist_keys))
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        DEBUG.plpy.info("sorted_dist_keys = {0}".format(sorted(self.dist_keys)))
+        DEBUG.plpy.info("max_dist_key = {0}".format(self.max_dist_key))
+        self.extra_dist_keys = []
+
+        num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
-        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
+        # Comma-separated list of the mst_keys, including NULL's
+        #  This will be used to pass the mst keys to the db as
+        #  a sql ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
 
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the # segments we'll be training on--these represent the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
 
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
+        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
+        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
             self.metrics_elapsed_start_time = time.time()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS {self.original_model_output_tbl};
+                                    CREATE TABLE {self.original_model_output_tbl} AS

Review comment:
       Why not drop the original model output table and then rename `model_output_tbl` to `original_model_output_tbl`?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""

Review comment:
       Is this `plpy.execute` only for the `plpy.info` commands? If so, maybe we should consider whether we really need to run this `plpy.execute` and the `plpy.info` prints.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -64,6 +64,13 @@ def reset_cuda_env(value):
         if CUDA_VISIBLE_DEVICES_KEY in os.environ:
             del os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
+def enable_xla():
+    os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
+    try:
+        K.tf.config.optimizer.set_jit(True)
+    except:
+        plpy.warning("This version of tensorflow does not support XLA auto-cluster JIT optimization.  HINT:  upgrading tensorflow may improve performance.")

Review comment:
       Maybe we should be more specific with the upgrade hint, since we don't support TensorFlow 2.0 yet.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -518,111 +567,130 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     # Fit segment model on data
     #TODO consider not doing this every time
     fit_params = parse_and_validate_fit_params(fit_params)
-    segment_model.fit(x_train, y_train, **fit_params)
+    with K.tf.device(device_name):
+        segment_model.fit(x_train, y_train, **fit_params)
 
     # Aggregating number of images, loss and accuracy
     agg_image_count += len(x_train)
+    SD[SD_STORE.AGG_IMAGE_COUNT] = agg_image_count
     total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
                                        agg_image_count, total_images)
+
     if is_last_row:
+        SD[SD_STORE.AGG_IMAGE_COUNT] = 0  # Must be reset after each pass through images
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
             clear_keras_session(sess)
 
+    trans_exit_time = time.time()
+    DEBUG.plpy.info("|_fit_transition_time_|{}|".format(trans_exit_time - trans_enter_time))
+
+    SD[SD_STORE.TRANS_EXIT_TIME] = trans_exit_time
     return return_state
 
-def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
-                             independent_var_shape, model_architecture,
-                             compile_params, fit_params, dist_key, dist_key_mapping,
-                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
-                             accessible_gpus_for_seg, prev_serialized_weights,
-                             is_final_training_call, custom_function_map=None, **kwargs):
+def fit_multiple_transition_caching(
+    dependent_var, independent_var, dependent_var_shape, independent_var_shape,
+    model_architecture, compile_params, fit_params, dist_key, dist_key_mapping,
+    current_seg_id, segments_per_host, images_per_seg, use_gpus, accessible_gpus_for_seg,
+    serialized_weights, is_final_training_call, custom_function_map=None, **kwargs):
     """
     This transition function is called when caching is called for
     madlib_keras_fit_multiple_model().
-    The input params: dependent_var, independent_var are passed in
-    as None and dependent_var_shape, independent_var_shape as [0]
-    for all hops except the very first hop
+    The input params: dependent_var, independent_var,
+    dependent_var_shape and independent_var_shape are passed
+    in as None for all hops except the very first hop
     Some things to note in this function are:
-    - prev_serialized_weights can be passed in as None for the
+    - weights can be passed in as None for the
       very first hop and the final training call
     - x_train, y_train and cache_set is cleared from SD for
-      final_training_call = TRUE
+      is_final_training_call = True
     """
-    if not state:
-        agg_image_count = 0
+    SD = kwargs['SD']
+
+    trans_enter_time = time.time()
+    trans_exit_time = None
+    if SD_STORE.TRANS_EXIT_TIME in SD:
+        trans_exit_time = SD[SD_STORE.TRANS_EXIT_TIME]
+        SD[SD_STORE.TRANS_EXIT_TIME] = trans_enter_time
+
+    if SD_STORE.AGG_IMAGE_COUNT in SD:
+        agg_image_count = SD[SD_STORE.AGG_IMAGE_COUNT]
     else:
-        agg_image_count = float(state)
+        agg_image_count = 0
+        SD[SD_STORE.AGG_IMAGE_COUNT] = agg_image_count
 
-    SD = kwargs['SD']
-    is_cache_set = 'cache_set' in SD
+    if agg_image_count > 0:
+        if trans_exit_time:
+            DEBUG.plpy.info("|_gpdb_btw_rows_time_|{}|".format(trans_enter_time - trans_exit_time))
+    else:
+        if trans_exit_time:
+            DEBUG.plpy.info("|_gpdb_btw_hops_end_time_|{}|".format(trans_enter_time - trans_exit_time))
 
     # Prepare the data
-    if is_cache_set:
+    if dependent_var_shape is None:
         if 'x_train' not in SD or 'y_train' not in SD:
             plpy.error("cache not populated properly.")
-        total_images = None
         is_last_row = True
+        total_images = None
     else:
-        if not independent_var or not dependent_var:
-            return state
-        if 'x_train' not in SD:
+        if 'x_train' not in SD or 'y_train' not in SD:
             SD['x_train'] = list()
             SD['y_train'] = list()
+
         agg_image_count += independent_var_shape[0]
-        total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
-                                                          images_per_seg)
+        SD[SD_STORE.AGG_IMAGE_COUNT] = agg_image_count
+        total_images = get_image_count_per_seg_from_array(
+            dist_key_mapping.index(dist_key), images_per_seg
+        )
         is_last_row = agg_image_count == total_images
-        if is_last_row:
-            SD['cache_set'] = True
         x_train_current = np_array_float32(independent_var, independent_var_shape)
         y_train_current = np_array_int16(dependent_var, dependent_var_shape)
         SD['x_train'].append(x_train_current)
         SD['y_train'].append(y_train_current)
 
     # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
     # But if the weights are None, we do not want to set any model. So early return in that case
-    if prev_serialized_weights is None:
+    if serialized_weights is None:
         if is_final_training_call:
             del SD['x_train']
             del SD['y_train']
-            del SD['cache_set']
-        return float(agg_image_count)
+        return None
 
     segment_model = None
     if is_last_row:
         device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
-        segment_model, sess = get_init_model_and_sess(SD, device_name,
-                                                      accessible_gpus_for_seg[current_seg_id],
-                                                      segments_per_host,
-                                                      model_architecture, compile_params,
-                                                      custom_function_map)
-        set_model_weights(segment_model, prev_serialized_weights)
+        with K.tf.device(device_name):

Review comment:
       We already wrap the code in `with K.tf.device(device_name):` inside the `get_init_model_and_sess` function, so we don't really need to do it here as well.
   

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
##########
@@ -1793,13 +1793,23 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     segments_per_host           INTEGER,
     images_per_seg              INTEGER[],
     use_gpus                    BOOLEAN,
-    accessible_gpus_for_seg                INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
     prev_serialized_weights     BYTEA,
     is_final_iteration          BOOLEAN,
-    custom_function_map        BYTEA
+    custom_function_map         BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        return madlib_keras.fit_transition(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'TransAggDetail' + detail

Review comment:
       Instead of adding `TransAggDetail` to the error message, why not use a keyword that is more relevant to the fit multiple UDF? If there is a specific reason for adding this keyword, I would recommend documenting it.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: {}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""

Review comment:
       Same as the previous comment:
   
   `Is this plpy.execute only for the plpy.info commands? If so, maybe we should consider whether we really need to run this plpy.execute and the plpy.info prints.`
   

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -629,149 +794,187 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
-        ind_shape_col = self.ind_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+            hop_query = """
+                CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                    SELECT o.{self.mst_key_col},
+                           o.{self.model_weights_col},
+                           o.{self.model_arch_col},
+                           o.{self.compile_params_col},
+                           o.{self.fit_params_col},
+                           o.{self.object_map_col},
+                           s.{self.dist_key_col}
+                    FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                        ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                    DISTRIBUTED BY ({self.dist_key_col});
+            """.format(self=self)
+
+            DEBUG.plpy.execute(hop_query)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            DEBUG.start_timing("truncate_output")
+            self.truncate_and_drop(self.model_output_tbl)
+            DEBUG.print_timing("truncate_output")
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
+        ind_shape = self.ind_shape_col
+        dep_shape = self.dep_shape_col
         dep_var = mb_dep_var_col
         indep_var = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape = ind_shape = 'NULL'
+                dep_var = indep_var = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
+
+        res = plpy.execute("""
+            SELECT count(*)
+            FROM {self.model_input_tbl}
+        """.format(self=self))
+        if res:
+            DEBUG.plpy.info("rows in model_input table: {}".format(res[0]['count']))
+        else:
+            DEBUG.plpy.error("No rows in model_input table!")
+
+#TODO: prepare this statement once, then just fill in the params with execute()
+#      on all the rest of the hops / iterations
+
+        DEBUG.start_timing("udf")
+        udf_query = plpy.prepare("""
+            CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+            SELECT
+                model_in.{self.mst_key_col},
+                CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                THEN
+                    model_in.{self.model_weights_col}
+                ELSE
+                    {self.schema_madlib}.fit_transition_multiple_model(
+                        {dep_var_col},
+                        {indep_var_col},
+                        {dep_shape},
+                        {ind_shape},
+                        model_in.{self.model_arch_col}::TEXT,
+                        model_in.{self.compile_params_col}::TEXT,
+                        model_in.{self.fit_params_col}::TEXT,
+                        src.{self.dist_key_col},
+                        ARRAY{self.dist_key_mapping},
+                        src.{self.gp_segment_id_col},
+                        {self.segments_per_host},
+                        ARRAY{self.images_per_seg_train},
+                        {self.use_gpus}::BOOLEAN,
+                        ARRAY{self.accessible_gpus_for_seg},
+                        model_in.{self.model_weights_col}::BYTEA,
+                        {self.is_final_training_call}::BOOLEAN,
+                        {use_caching}::BOOLEAN,
+                        model_in.{self.object_map_col}::BYTEA
+                    )
+                END::BYTEA AS {self.model_weights_col},
+                model_in.{self.model_arch_col},
+                model_in.{self.compile_params_col},
+                model_in.{self.fit_params_col},
+                model_in.{self.object_map_col},
+                model_in.{self.dist_key_col}
+            FROM {self.model_input_tbl} model_in
+                FULL JOIN {source_table} src
+                USING ({self.dist_key_col}) 
+            DISTRIBUTED BY({self.dist_key_col})
+            """.format(dep_var_col=dep_var,
+                       indep_var_col=indep_var,
+                       dep_shape=dep_shape,
+                       ind_shape=ind_shape,
                        use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
                        source_table=source_table,
-                       where_clause=where_clause,
                        self=self
                        )
-        plpy.execute(uda_query)
+        )
+
+        try:
+            plpy.execute(udf_query)
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'TransAggDetail' in msg:
+                raise e
+            e.message, detail = msg.split('TransAggDetail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_key, {self.model_weights_col} IS NOT NULL AS weights
+                FROM {self.model_output_tbl}
+        """.format(self=self))
+        if res:
+            null_msts = len([None for row in res if row['mst_key'] is None])
+            null_weights = len([None for row in res if row['weights'] is False])
+            DEBUG.plpy.info(
+                "{} rows total ({} mst_key=NULL and {} weights=NULL) in model_output table."\
+                    .format(res.nrows(), null_msts, null_weights))
+        else:
+            plpy.error("No rows in output of UDF!")
 
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights IS NULL".format(self=self))

Review comment:
       Instead of running a DELETE command, can't we filter out the null model_weights when executing the udf query?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)

Review comment:
       In the current implementation of `DEBUG.plpy.execute`, we still execute the query but don't print the plan or the timing.
   This is a bit confusing: seeing the word DEBUG in front of plpy.execute, a reader will naturally assume that the query is only executed when the debug flag is turned on.
   
   So I would suggest the following:
   1. Reduce the number of DEBUG.plpy.execute calls to only a couple, maybe just the hop and the udf query.
   2. Don't use the word DEBUG; use something like `plpy_execute` instead, so as not to confuse the reader.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -495,21 +500,45 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         b. keras session is cleared at the end of the final iteration,
         i.e, last row of last iteration.
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape:
+        plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state
+
     SD = kwargs['SD']
+
+    trans_enter_time = time.time()
+
+    trans_exit_time = None
+    if SD_STORE.TRANS_EXIT_TIME in SD:
+        trans_exit_time = SD[SD_STORE.TRANS_EXIT_TIME]
+        SD[SD_STORE.TRANS_EXIT_TIME] = None
+
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
 
-    segment_model, sess = get_init_model_and_sess(SD, device_name,
-                                                  accessible_gpus_for_seg[current_seg_id],
-                                                  segments_per_host,
-                                                  model_architecture, compile_params,
-                                                  custom_function_map)
-    if not state:
+    with K.tf.device(device_name):

Review comment:
       We already wrap the code with `K.tf.device(device_name):` inside the `get_init_model_and_sess` function. We don't really need to do it here as well.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +377,307 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
 
-    def create_model_output_table_warm_start(self):
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
+
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        DEBUG.plpy.execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.mst_key_col}, {self.dist_key_col}))
+                                     DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:

Review comment:
       1. Maybe consider renaming these two functions, since their purpose is not only to load the weights but also model_arch, compile_params, fit_params, etc.
   2. Also consider not including the keyword `xfer_learning` in the function name, since it is also used in the case where model_weights are null in the model arch table.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -81,6 +88,7 @@ def set_keras_session(device_name, gpu_count, segments_per_host):
     with K.tf.device(device_name):
         session = get_keras_session(device_name, gpu_count, segments_per_host)
         K.set_session(session)
+        enable_xla()

Review comment:
       enable_xla() is already called inside `get_keras_session`. Since set_keras_session calls `get_keras_session`, we don't really need to call enable_xla() here.

##########
File path: src/ports/postgres/modules/utilities/debug.py_in
##########
@@ -0,0 +1,145 @@
+import plpy as plpy_orig
+import time
+from deep_learning.madlib_keras_model_selection import ModelSelectionSchema
+from deep_learning.madlib_keras_helper import DISTRIBUTION_KEY_COLNAME
+
+mst_key_col = ModelSelectionSchema.MST_KEY
+dist_key_col = DISTRIBUTION_KEY_COLNAME
+
+start_times = dict()
+timings_enabled = False
+
+def start_timing(msg, force=False):
+    if timings_enabled or force:
+        start_times[msg] = time.time()
+        plpy_orig.info("|_{}_time_HDR|Elapsed (s)|Current|Current (s)|Start|Start (s)|".format(msg))
+
+def print_timing(msg, force=False):
+    if timings_enabled or force:
+        try:
+            start_time = start_times[msg]
+        except:
+            raise Exception(
+                "print_timing({msg}) called with no start_timing({msg})!".format(msg=msg)
+            )
+        current_time = time.time() 
+        plpy_orig.info(
+            '|_{0}_time|{1}|{2}|{3}|{4}|{5}'.format(
+                msg,
+                current_time - start_time,
+                time.ctime(current_time),
+                current_time,
+                time.ctime(start_time),
+                start_time
+            )
+        )
+
+mst_keys_enabled = False
+def print_mst_keys(table, label, force=False):
+    if not (mst_keys_enabled or force):
+        return
+
+    res = plpy_orig.execute("""
+        SELECT gp_segment_id AS seg_id,
+               {mst_key_col},
+               {dist_key_col}
+        FROM {table} ORDER BY {dist_key_col}
+    """.format(dist_key_col=dist_key_col,
+               table=table,
+               mst_key_col=mst_key_col))
+
+    plpy_orig.info("|_MST_KEYS_{label}_HDR|mst_key|seg_id|dist_key|table".format(**locals()))
+    if not res:
+        plpy_orig.error("{table} is empty!  Aborting".format(table=table))
+
+    for r in res:
+        seg_id = r['seg_id']
+        mst_key = r['mst_key']
+        dist_key = r[dist_key_col]
+        plpy_orig.info("|_MST_KEYS_{label}|{mst_key}|{seg_id}|{dist_key}|{table}".format(**locals()))
+
+plpy_execute_enabled = False
+def plpy_execute(*args, **kwargs):
+    """ debug.plpy.execute(sql, ..., force=False)
+
+        Replace plpy.execute(sql, ...) with
+        debug.plpy.execute(sql, ...) to debug
+        a query.  Shows the query itself, the
+        EXPLAIN of it, and how long the query
+        takes to execute.
+    """
+
+    force = False
+    if 'force' in kwargs:
+        force = kwargs['force']
+        del kwargs['force']
+
+    plpy = plpy_orig # override global plpy,
+                     # to avoid infinite recursion
+
+    if not (plpy_execute_enabled or force):
+        return plpy.execute(*args, **kwargs)
+
+    if len(args) > 0:
+        sql = args[0]
+    else:
+        raise TypeError('debug.plpy.execute() takes at least 1 parameter, 0 passed')
+
+    if type(sql) == str: # can't print if a PLyPlan object
+        plpy.info(sql)
+
+        # Print EXPLAIN of sql command
+        res = plpy.execute("EXPLAIN " + sql, *args[1:], **kwargs)
+        for r in res:
+            plpy.info(r['QUERY PLAN'])
+
+    # Run actual sql command, with timing
+    start = time.time()
+    res = plpy.execute(*args, **kwargs)
+
+    # Print how long execution of query took
+    plpy.info("Query took {0}s".format(time.time() - start))
+    if res:
+        plpy.info("Query returned {} row(s)".format(len(res)))
+    else:
+        plpy.info("Query returned 0 rows")
+    return res
+
+plpy_info_enabled = False
+def plpy_info(*args, **kwargs):
+    """ plpy_info(..., force=False)
+
+      plpy.info() if enabled, otherwise do nothing   
+    """
+
+    force = False
+    if 'force' in kwargs:
+        force = kwargs['force']
+        del kwargs['force']
+
+    if plpy_info_enabled or force:
+        plpy_orig.info(*args, **kwargs)
+
+plpy_debug_enabled = False
+def plpy_debug(*args, **kwargs):
+    """ debug.plpy.debug(..., force=False)
+
+        Behaves like plpy.debug() if disabled (printing only
+        if DEBUG level is set high enough), but becomes a
+        plpy.info() if enabled.
+    """
+
+    force = False
+    if 'force' in kwargs:
+        force = kwargs['force']
+        del kwargs['force']
+
+    if plpy_debug_enabled or force:
+        plpy_orig.info(*args, **kwargs)
+    else:
+        plpy_orig.debug(*args, **kwargs)
+
+class plpy:
+    execute = staticmethod(plpy_execute)
+    info = staticmethod(plpy_info)
+    debug = staticmethod(plpy_debug)

Review comment:
       Since this function is never used, do you think it makes sense to delete this line and the `plpy_debug` function as well (given that developers can still use plpy_info for printing purposes)?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
##########
@@ -337,183 +376,308 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        self.prev_dist_key_col = '__prev_dist_key__'
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        DEBUG.plpy.execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if not hasattr(self, 'rotate_schedule_plan'):
+            self.next_schedule_tbl = unique_string('next_schedule')
+            rotate_schedule_tbl_query = """
+                CREATE TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl};
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        DEBUG.plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
 
-    def create_model_output_table_warm_start(self):
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        DEBUG.plpy.execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(self.model_info_tbl))
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        DEBUG.plpy.execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.dist_key_col}, {self.mst_key_col})
+                                    )
+                                    DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:
+            self.load_warm_start_weights()
+        else:  # Note:  We only support xfer learning when warm_start=False
+            self.load_xfer_learning_weights()
+
+        res = DEBUG.plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_keys FROM {self.model_output_tbl}
+        """.format(self=self))
+       
+        if res:
+            initialized_msts = set([ row['mst_keys'] for row in res ])
+        else:
+            initialized_msts = set()
+
+        DEBUG.plpy.info("Pre-initialized mst keys: {}".format(initialized_msts))

Review comment:
       We should try to reduce the frequency of `DEBUG.plpy.info` code lines, since they can be feature-specific. Having a lot of these lines might make the code slightly harder to read.
   Any developer working on a feature can add their own `DEBUG.plpy.info` calls as needed.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org