You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by do...@apache.org on 2021/01/20 00:18:32 UTC

[madlib] 01/03: DL: Major Refactor of Model Hopper

This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit bc1b6e12376d3b9d484cf7d49f5918207599fedb
Author: Domino Valdano <dv...@vmware.com>
AuthorDate: Wed Sep 30 19:47:51 2020 -0700

    DL: Major Refactor of Model Hopper
    
    JIRA: MADLIB-1428
    
    - Use only 2 temporary tables (model_input_tbl & model_output_tbl)
      for moving the model weights around during hopping and training,
      instead of 3 (mst_weights_tbl, weights_to_update_tbl, and model_output_table).
      This eliminates the UPDATE step, leaving only HOP and UDF steps.
    
    - Add dist_key column to model_output table and DISTRIBUTE BY this instead
       of mst_key.  This removes Redistribute Motion from UDF query plan, so
       that weights only ever move during the hop query, not during the
       training query.
    
    - Simplified schedule rotation: schedule table created only once, then gets
      rotated on segments, instead of being re-created many times by transferring
      data back and forth from master to segments to master each hop.  No longer
      need separate "current_schedule" and "grand_schedule" data structures.
    
    - Skip first hop of each iteration
       (just rename model_output to model_input instead)
    
    - Split get_model_arch_and_weights() into query_weights() and get_model_arch()
        So we don't have to transfer weights from segment to master in places
        where we only need the model_arch json.
    
    - Much faster initialization code:  previously, we were reading the weights
      in from the original model output table (during warm start) and the model
      arch table (for transfer learning) one mst row at a time from segment to
      master, then writing them each back out one row at a time from master
      back to segments with a large number of SELECT and INSERT queries.
      Now, we just use a single query to copy the weights directly from the
      original model output table into the new model output table on the
      segments, without ever sending them to master.  And a similar single
      query copies the transfer learning weights directly from model_arch to
      model_output for training.  Both of these happen in parallel on the
      segments, instead of in sequence on master.  During testing on
      a 20-segment cluster with 20 models, this resulted in a 10x reduction
      in initialization time (26s instead of 5 mins)
    
    - Add some debugging that can be enabled to help profile the
      performance of fit multiple, and track which segment each mst_key
      is located on during each hop. This also serves as an example for
      the utils/debug PR this is rebased on top of.
    
    - Add "unit" tests for fit mult model hopping code (implemented
      as dev-check tests so they can access the db)
    
    - Send Traceback of stack from segment back to coordinator
    
    - Cache plans for Hop & UDF queries
---
 .../modules/deep_learning/madlib_keras.py_in       | 239 ++++--
 .../modules/deep_learning/madlib_keras.sql_in      |  40 +-
 .../madlib_keras_automl_hyperband.py_in            |   1 +
 .../madlib_keras_automl_hyperopt.py_in             |   1 +
 .../madlib_keras_fit_multiple_model.py_in          | 949 +++++++++++++--------
 .../madlib_keras_fit_multiple_model.sql_in         |  90 +-
 .../deep_learning/madlib_keras_helper.py_in        |  17 +-
 .../deep_learning/madlib_keras_predict.py_in       |   1 -
 .../deep_learning/madlib_keras_serializer.py_in    |   1 +
 .../deep_learning/madlib_keras_wrapper.py_in       |  16 +-
 .../modules/deep_learning/model_arch_info.py_in    |  25 +-
 .../test/unit_tests/test_madlib_keras.py_in        | 136 ++-
 12 files changed, 905 insertions(+), 611 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index a3a8ae5..ba7f2b7 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -27,6 +27,7 @@ from madlib_keras_helper import *
 from madlib_keras_validator import *
 from madlib_keras_wrapper import *
 from model_arch_info import *
+import tensorflow as tf
 
 from madlib_keras_model_selection import ModelSelectionSchema
 
@@ -42,6 +43,10 @@ from utilities.validate_args import quote_ident
 from utilities.control import MinWarning
 
 import tensorflow as tf
+import utilities.debug as DEBUG
+
+DEBUG.timings_enabled = False
+DEBUG.plpy_info_enabled = False
 
 from tensorflow.keras import backend as K
 from tensorflow.keras.layers import *
@@ -52,6 +57,7 @@ from tensorflow.keras.regularizers import *
 class GD_STORE:
     SESS = 'sess'
     SEGMENT_MODEL = 'segment_model'
+    AGG_IMAGE_COUNT = 'agg_image_count'
 
     @staticmethod
     def init(GD, sess, segment_model):
@@ -62,23 +68,27 @@ class GD_STORE:
     def clear(GD):
         del GD[GD_STORE.SEGMENT_MODEL]
         del GD[GD_STORE.SESS]
+        if GD_STORE.AGG_IMAGE_COUNT in GD:
+            del GD[GD_STORE.AGG_IMAGE_COUNT]
 
 def get_init_model_and_sess(GD, device_name, gpu_count, segments_per_host,
                                model_architecture, compile_params, custom_function_map):
     # If a live session is present, re-use it. Otherwise, recreate it.
-    if GD_STORE.SESS in GD:
+
+    if GD_STORE.SESS in GD :
         if GD_STORE.SEGMENT_MODEL not in GD:
             plpy.error("Session and model should exist in GD after the first row"
-                       " of the first iteration")
-        sess = GD[GD_STORE.SESS]
-        segment_model = GD[GD_STORE.SEGMENT_MODEL]
-        K.set_session(sess)
+                       "of the first iteration")
+        with tf.device(device_name):
+            sess = GD[GD_STORE.SESS]
+            segment_model = GD[GD_STORE.SEGMENT_MODEL]
+            K.set_session(sess)
     else:
-        sess = get_keras_session(device_name, gpu_count, segments_per_host)
-        K.set_session(sess)
-        segment_model = init_model(model_architecture, compile_params, custom_function_map)
+        with tf.device(device_name):
+            sess = get_keras_session(device_name, gpu_count, segments_per_host)
+            K.set_session(sess)
+            segment_model = init_model(model_architecture, compile_params, custom_function_map)
         GD_STORE.init(GD, sess, segment_model)
-
     return segment_model, sess
 
 @MinWarning("warning")
@@ -118,7 +128,6 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     if metrics_compute_frequency is None:
         metrics_compute_frequency = num_iterations
 
-
     warm_start = bool(warm_start)
 
     # The following two times must be recorded together.
@@ -140,12 +149,12 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
 
     serialized_weights = get_initial_weights(model, model_arch, model_weights,
-                                             warm_start, use_gpus, accessible_gpus_for_seg)
+                                             warm_start, accessible_gpus_for_seg)
     # Compute total images on each segment
     dist_key_mapping, images_per_seg_train = get_image_count_per_seg_for_minibatched_data_from_db(source_table)
 
     if validation_table:
-        seg_ids_val, images_per_seg_val = get_image_count_per_seg_for_minibatched_data_from_db(validation_table)
+        dist_key_mapping_val, images_per_seg_val = get_image_count_per_seg_for_minibatched_data_from_db(validation_table)
 
     # Construct validation dataset if provided
     validation_set_provided = bool(validation_table)
@@ -199,9 +208,29 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     for i in range(1, num_iterations+1):
         start_iteration = time.time()
         is_final_iteration = (i == num_iterations)
-        serialized_weights = plpy.execute(run_training_iteration,
-                                        [serialized_weights, custom_function_map]
-                                        )[0]['iteration_result']
+
+        try:
+            serialized_weights = plpy.execute(run_training_iteration,
+                                            [serialized_weights, custom_function_map]
+                                            )[0]['iteration_result']
+        except plpy.SPIError as e:
+            msg = e.message
+            if 'TransAggDetail' in msg:
+                e.message, detail = msg.split('TransAggDetail')
+            elif 'MergeAggDetail' in msg:
+                e.message, detail = msg.split('MergeAggDetail')
+            elif 'FinalAggDetail' in msg:
+                e.message, detail = msg.split('FinalAggDetail')
+            else:
+                raise e
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
         end_iteration = time.time()
         info_str = "\tTime for training in iteration {0}: {1} sec".format(i,
             end_iteration - start_iteration)
@@ -240,7 +269,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
                                                            serialized_weights,
                                                            use_gpus,
                                                            accessible_gpus_for_seg,
-                                                           seg_ids_val,
+                                                           dist_key_mapping_val,
                                                            images_per_seg_val,
                                                            validation_metrics,
                                                            validation_loss,
@@ -376,7 +405,7 @@ def get_evaluate_info_msg(i, info_str, compute_out, is_train):
 
 
 def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
-                        use_gpus, accessible_gpus_for_seg, mst_filter=''):
+                        accessible_gpus_for_seg, mst_filter=''):
     """
         If warm_start is True, return back initial weights from model table.
         If warm_start is False, first try to get the weights from model_arch
@@ -391,12 +420,14 @@ def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
         will only be used for segment nodes.
         @args:
             @param model_table: Output model table passed in to fit.
-            @param model_arch_result: Dict containing model architecture info.
+            @param model_arch: Dict containing model architecture info.
             @param warm_start: Boolean flag indicating warm start or not.
     """
     if is_platform_pg():
+        # Use GPU's if they are enabled
         _ = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[0], None)
-    else:
+    else: # gpdb
+        # We are on master, so never use GPU's
         _ = get_device_name_and_set_cuda_env(0, None)
 
     if warm_start:
@@ -435,7 +466,7 @@ def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
                              serialized_weights, use_gpus,
                              accessible_gpus_for_seg, dist_key_mapping,
                              images_per_seg_val, metrics_list, loss_list,
-                             should_clear_session, custom_fn_name,
+                             should_clear_session, custom_fn_map,
                              model_table=None, mst_key=None):
     """
     Compute the loss and metric using a given model (serialized_weights) on the
@@ -452,7 +483,7 @@ def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
                                                    dist_key_mapping,
                                                    images_per_seg_val,
                                                    should_clear_session,
-                                                   custom_fn_name,
+                                                   custom_fn_map,
                                                    model_table,
                                                    mst_key)
     end_val = time.time()
@@ -491,15 +522,6 @@ def init_model(model_architecture, compile_params, custom_function_map):
     compile_model(segment_model, compile_params, custom_function_map)
     return segment_model
 
-def update_model(segment_model, prev_serialized_weights):
-    """
-        Happens at first row of each iteration.
-    """
-    model_shapes = get_model_shapes(segment_model)
-    model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
-        prev_serialized_weights, model_shapes)
-    segment_model.set_weights(model_weights)
-
 def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                    independent_var_shape, model_architecture,
                    compile_params, fit_params, dist_key, dist_key_mapping,
@@ -520,21 +542,32 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape or not independent_var_shape\
+        or dependent_var is None or independent_var is None:
+            plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state
+
     GD = kwargs['GD']
+
+    trans_enter_time = time.time()
+
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
 
     segment_model, sess = get_init_model_and_sess(GD, device_name,
-                                                  accessible_gpus_for_seg[current_seg_id],
-                                                  segments_per_host,
-                                                  model_architecture, compile_params,
-                                                  custom_function_map)
-    if not state:
-        agg_image_count = 0
-        set_model_weights(segment_model, prev_serialized_weights)
+        accessible_gpus_for_seg[current_seg_id],
+        segments_per_host,
+        model_architecture, compile_params,
+        custom_function_map)
+
+    if GD_STORE.AGG_IMAGE_COUNT in GD:
+        agg_image_count = GD[GD_STORE.AGG_IMAGE_COUNT]
     else:
-        agg_image_count = float(state)
+        agg_image_count = 0
+        GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
+        with tf.device(device_name):
+            set_model_weights(segment_model, prev_serialized_weights)
 
     # Prepare the data
     x_train = np_array_float32(independent_var, independent_var_shape)
@@ -543,65 +576,76 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     # Fit segment model on data
     #TODO consider not doing this every time
     fit_params = parse_and_validate_fit_params(fit_params)
-    segment_model.fit(x_train, y_train, **fit_params)
+    with tf.device(device_name):
+        segment_model.fit(x_train, y_train, **fit_params)
 
     # Aggregating number of images, loss and accuracy
     agg_image_count += len(x_train)
+    GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
     total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
                                        agg_image_count, total_images)
-    if is_multiple_model and is_last_row:
-        GD_STORE.clear(GD)
-        clear_keras_session(sess)
+
+    if is_last_row:
+        del GD[GD_STORE.AGG_IMAGE_COUNT]  # Must be reset after each pass through images
+        if is_multiple_model:
+            GD_STORE.clear(GD)
+            clear_keras_session(sess)
+
+    trans_exit_time = time.time()
+    DEBUG.plpy.info("|_fit_transition_time_|{}|".format(trans_exit_time - trans_enter_time))
 
     return return_state
 
-def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
+def fit_multiple_transition_caching(dependent_var, independent_var, dependent_var_shape,
                              independent_var_shape, model_architecture,
                              compile_params, fit_params, dist_key, dist_key_mapping,
                              current_seg_id, segments_per_host, images_per_seg,
-                             accessible_gpus_for_seg, prev_serialized_weights,
+                             accessible_gpus_for_seg, serialized_weights,
                              is_final_training_call, custom_function_map=None, **kwargs):
     """
     This transition function is called when caching is called for
     madlib_keras_fit_multiple_model().
-    The input params: dependent_var, independent_var are passed in
-    as None and dependent_var_shape, independent_var_shape as [0]
-    for all hops except the very first hop
+    The input params: dependent_var, independent_var,
+    dependent_var_shape and independent_var_shape are passed
+    in as None for all hops except the very first hop
     Some things to note in this function are:
-    - prev_serialized_weights can be passed in as None for the
-      very first hop and the final training call
+    - weights can be passed in as None for the very first hop
+      and the final training call.  (This can only happen if
+      num msts < num segs)
     - x_train, y_train and cache_set is cleared from GD for
-      final_training_call = TRUE
+      is_final_training_call = True
     """
-    if not state:
-        agg_image_count = 0
-    else:
-        agg_image_count = float(state)
-
     GD = kwargs['GD']
-    is_cache_set = 'cache_set' in GD
+
+    trans_enter_time = time.time()
+
+    if GD_STORE.AGG_IMAGE_COUNT in GD:
+        agg_image_count = GD[GD_STORE.AGG_IMAGE_COUNT]
+    else:
+        agg_image_count = 0
+        GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
 
     # Prepare the data
-    if is_cache_set:
+    if not dependent_var_shape or not independent_var_shape \
+        or dependent_var is None or independent_var is None:
         if 'x_train' not in GD or 'y_train' not in GD:
             plpy.error("cache not populated properly.")
-        total_images = None
         is_last_row = True
+        total_images = None
     else:
-        if not independent_var or not dependent_var:
-            return state
-        if 'x_train' not in GD:
+        if 'x_train' not in GD or 'y_train' not in GD:
             GD['x_train'] = list()
             GD['y_train'] = list()
+
         agg_image_count += independent_var_shape[0]
-        total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
-                                                          images_per_seg)
+        GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
+        total_images = get_image_count_per_seg_from_array(
+            dist_key_mapping.index(dist_key), images_per_seg
+        )
         is_last_row = agg_image_count == total_images
-        if is_last_row:
-            GD['cache_set'] = True
         x_train_current = np_array_float32(independent_var, independent_var_shape)
         y_train_current = np_array_int16(dependent_var, dependent_var_shape)
         GD['x_train'].append(x_train_current)
@@ -609,15 +653,16 @@ def fit_multiple_transition_caching(state, dependent_var, independent_var, depen
 
     # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
     # But if the weights are None, we do not want to set any model. So early return in that case
-    if prev_serialized_weights is None:
+    if serialized_weights is None:
         if is_final_training_call:
+            del GD[GD_STORE.AGG_IMAGE_COUNT]
             del GD['x_train']
             del GD['y_train']
-            del GD['cache_set']
-        return float(agg_image_count)
+        return None
 
     segment_model = None
     sess = None
+
     if is_last_row:
         device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
         segment_model, sess = get_init_model_and_sess(GD, device_name,
@@ -625,29 +670,34 @@ def fit_multiple_transition_caching(state, dependent_var, independent_var, depen
                                                       segments_per_host,
                                                       model_architecture, compile_params,
                                                       custom_function_map)
-        set_model_weights(segment_model, prev_serialized_weights)
 
-        fit_params = parse_and_validate_fit_params(fit_params)
-        for i in range(len(GD['x_train'])):
-            # Fit segment model on data
-            segment_model.fit(GD['x_train'][i], GD['y_train'][i], **fit_params)
+        with tf.device(device_name):
+            set_model_weights(segment_model, serialized_weights)
+            fit_params = parse_and_validate_fit_params(fit_params)
 
+            for i in range(len(GD['x_train'])):
+                # Fit segment model on data
+                segment_model.fit(GD['x_train'][i], GD['y_train'][i], **fit_params)
 
     return_state = get_state_to_return(segment_model, is_last_row, True,
-                                       agg_image_count, total_images)
+                                       agg_image_count)
 
     if is_last_row:
         GD_STORE.clear(GD)
         clear_keras_session(sess)
         if is_final_training_call:
+            if GD_STORE.AGG_IMAGE_COUNT in GD:
+                del GD[GD_STORE.AGG_IMAGE_COUNT]
             del GD['x_train']
             del GD['y_train']
-            del GD['cache_set']
+
+    trans_exit_time = time.time()
+    DEBUG.plpy.info("|_fit_multiple_transition_caching_time_|{}|".format(trans_exit_time - trans_enter_time))
 
     return return_state
 
 def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image_count,
-                        total_images):
+                        total_images=None):
     """
     1. For both model averaging fit_transition and fit multiple transition, the
     state only needs to have the image count except for the last row.
@@ -663,17 +713,20 @@ def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image
     :param is_last_row: boolean to indicate if last row for that hop
     :param is_multiple_model: boolean
     :param agg_image_count: aggregated image count per hop
-    :param total_images: total images per segment
+    :param total_images: total images per segment (only used for madlib_keras_fit() )
     :return:
     """
-    if is_last_row:
-        updated_model_weights = segment_model.get_weights()
-        if is_multiple_model:
+    if is_multiple_model:
+        if is_last_row:
+            updated_model_weights = segment_model.get_weights()
             new_state = madlib_keras_serializer.serialize_nd_weights(updated_model_weights)
         else:
-            updated_model_weights = [total_images * w for w in updated_model_weights]
-            new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
-                agg_image_count, updated_model_weights)
+            new_state = None
+    elif is_last_row:
+        updated_model_weights = segment_model.get_weights()
+        updated_model_weights = [total_images * w for w in updated_model_weights]
+        new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
+            agg_image_count, updated_model_weights)
     else:
         new_state = float(agg_image_count)
 
@@ -808,8 +861,12 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
                                     accessible_gpus_for_seg, dist_key_mapping, images_per_seg,
                                     should_clear_session=True, custom_function_map=None,
                                     model_table=None, mst_key=None):
+    """
+    This function will call the internal keras evaluate function to get the loss
+    and accuracy of each tuple which then gets averaged to get the final result.
+    """
 
-    dist_key_col = '0' if is_platform_pg() else DISTRIBUTION_KEY_COLNAME
+    dist_key_col = '0' if is_platform_pg() else '__table__.{0}'.format(DISTRIBUTION_KEY_COLNAME)
     gp_segment_id_col = '0' if is_platform_pg() else '__table__.{0}'.format(GP_SEGMENT_ID_COLNAME)
     segments_per_host = get_segments_per_host()
 
@@ -820,10 +877,7 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
         MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, "_shape")
     ind_shape_col = add_postfix(
         MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, "_shape")
-    """
-    This function will call the internal keras evaluate function to get the loss
-    and accuracy of each tuple which then gets averaged to get the final result.
-    """
+
     use_gpus = use_gpus if use_gpus else False
 
     eval_sql = """
@@ -861,9 +915,12 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
         evaluate_query = plpy.prepare(eval_sql.format(**locals()), ["bytea", "bytea"])
         res = plpy.execute(evaluate_query, [serialized_weights, custom_function_map])
 
-    loss_metric = res[0]['loss_metric']
-    return loss_metric
 
+    if res is None:
+        plpy.error("Zero rows returned from evaluate query: {}".format(evaluate_query))
+    else:
+        loss_metric = res[0]['loss_metric']
+    return loss_metric
 
 def internal_keras_eval_transition(state, dependent_var, independent_var,
                                    dependent_var_shape, independent_var_shape,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index ff00fa6..e0e0fb5 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -1797,7 +1797,17 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     custom_function_map         BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        return madlib_keras.fit_transition(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'TransAggDetail' + detail
+        e.args = (message,)
+        raise e
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1806,7 +1816,18 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_merge(
     state2          BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_merge(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+
+    try:
+        return madlib_keras.fit_merge(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'MergeAggDetail' + detail
+        e.args = (message,)
+        raise e
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1814,7 +1835,18 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_final(
     state BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_final(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        return madlib_keras.fit_final(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'FinalAggDetail' + detail
+        e.args = (message,)
+        raise e
+
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1850,7 +1882,7 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
     /* segments_per_host */      INTEGER,
     /* images_per_seg */         INTEGER[],
     /* segments_per_host  */     INTEGER[],
-    /* serialized_weights */     BYTEA,
+    /* prev_serialized_weights */BYTEA,
     /* custom_loss_cfunction */  BYTEA
 )(
     STYPE=BYTEA,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperband.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperband.py_in
index 2567b42..d44c3ea 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperband.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperband.py_in
@@ -253,6 +253,7 @@ class AutoMLHyperband(KerasAutoML):
                 model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLConstants.MODEL_OUTPUT_TABLE,
                                                 AutoMLConstants.MST_TABLE, num_iterations, self.use_gpus,
                                                 self.validation_table, mcf, self.warm_start, self.name, self.description)
+                model_training.fit_multiple_model()
             self.update_model_output_table()
             self.update_model_output_info_table(i, initial_vals)
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperopt.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperopt.py_in
index 9825f76..b852e14 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperopt.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperopt.py_in
@@ -161,6 +161,7 @@ class AutoMLHyperopt(KerasAutoML):
                                                   AutoMLConstants.MST_TABLE, self.num_iters, self.use_gpus, self.validation_table,
                                                   self.metrics_compute_frequency, False, self.name, self.description,
                                                   metrics_elapsed_time_offset=metrics_elapsed_time_offset)
+                model_training.fit_multiple_model()
             metrics_elapsed_time_offset += time.time() - start_time
             if make_mst_summary:
                 self.generate_mst_summary_table(self.model_selection_summary_table)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index a03f6cb..182a7a1 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -20,18 +20,22 @@
 import plpy
 import time
 import sys
+import json
+import random
+import datetime
+from collections import defaultdict
+# from tensorflow.keras.models import *
 
 from madlib_keras import compute_loss_and_metrics
-from madlib_keras import get_initial_weights
-from madlib_keras import get_model_arch_weights
+from madlib_keras import get_model_arch
 from madlib_keras import get_source_summary_table_dict
 from madlib_keras import should_compute_metrics_this_iter
+from madlib_keras import get_initial_weights
 from madlib_keras_helper import *
 from madlib_keras_model_selection import ModelSelectionSchema
 from madlib_keras_validator import *
 from madlib_keras_wrapper import *
 
-from utilities.control import MinWarning
 from utilities.control import OptimizerControl
 from utilities.control import SetGUC
 from utilities.utilities import add_postfix
@@ -43,16 +47,17 @@ from utilities.utilities import is_platform_pg
 from utilities.utilities import get_seg_number
 from utilities.utilities import get_segments_per_host
 from utilities.utilities import rename_table
+import utilities.debug as DEBUG
+from utilities.debug import plpy_prepare
+from utilities.debug import plpy_execute
 
-import json
-from collections import defaultdict
-import random
-import datetime
+DEBUG.timings_enabled = False
+DEBUG.mst_keys_enabled = False
+DEBUG.plpy_execute_enabled = False
+DEBUG.plpy_info_enabled = False
 
-from tensorflow.keras.models import *
 mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
 mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
-dist_key_col = DISTRIBUTION_KEY_COLNAME
 
 """
 FitMultipleModel: This class implements the Model Hopper technique for
@@ -76,8 +81,7 @@ segment.
 Note that this function is disabled for Postgres.
 """
 
-@MinWarning("warning")
-class FitMultipleModel():
+class FitMultipleModel(object):
     def __init__(self, schema_madlib, source_table, model_output_table,
                  model_selection_table, num_iterations,
                  use_gpus=False, validation_table=None,
@@ -113,6 +117,8 @@ class FitMultipleModel():
         if self.model_selection_table:
             self.model_selection_summary_table = add_postfix(self.model_selection_table, '_summary')
 
+        self.dist_key_col = DISTRIBUTION_KEY_COLNAME
+        self.prev_dist_key_col = '__prev_dist_key__'
         self.num_iterations = num_iterations
         self.metrics_compute_frequency = metrics_compute_frequency
         self.name = name
@@ -134,57 +140,56 @@ class FitMultipleModel():
         self.info_str = ""
         self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
         self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
-        self.use_gpus = use_gpus
+        self.use_gpus = use_gpus if use_gpus else False
         self.segments_per_host = get_segments_per_host()
+        self.model_input_tbl = unique_string('model_input')
+        self.model_output_tbl = unique_string('model_output')
+        self.schedule_tbl = unique_string('schedule')
+        self.next_schedule_tbl = unique_string('next_schedule')
         self.cached_source_table = unique_string('cached_source_table')
         self.metrics_elapsed_time_offset = metrics_elapsed_time_offset
+        self.rotate_schedule_tbl_plan = self.add_object_maps_plan = None
+        self.hop_plan = self.udf_plan = None
+
         if self.use_gpus:
             self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
                 self.schema_madlib, self.segments_per_host, self.module_name)
         else:
             self.accessible_gpus_for_seg = get_seg_number()*[0]
 
-        self.original_model_output_table = model_output_table
-        if self.original_model_output_table:
-            self.model_info_table = add_postfix(self.original_model_output_table, '_info')
-            self.model_summary_table = add_postfix(
-                self.original_model_output_table, '_summary')
+        self.original_model_output_tbl = model_output_table
+        if not self.original_model_output_tbl:
+	    plpy.error("Must specify an output table.")
 
-        self.model_output_table = self.original_model_output_table
+        self.model_info_tbl = add_postfix(
+            self.original_model_output_tbl, '_info')
+        self.model_summary_table = add_postfix(
+            self.original_model_output_tbl, '_summary')
 
-        """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
-        """
         self.warm_start = bool(warm_start)
-        self.warm_start_msts = []
-        if self.warm_start:
-            self.model_output_table = unique_string('initial_model')
 
         self.fit_validator_train = FitMultipleInputValidator(
-            self.source_table, self.validation_table, self.original_model_output_table,
+            self.source_table, self.validation_table, self.original_model_output_tbl,
             self.model_selection_table, self.model_selection_summary_table,
             mb_dep_var_col, mb_indep_var_col, self.num_iterations,
-            self.model_info_table, self.mst_key_col, self.model_arch_table_col,
+            self.model_info_tbl, self.mst_key_col, self.model_arch_table_col,
             self.metrics_compute_frequency, self.warm_start, self.use_gpus,
             self.accessible_gpus_for_seg)
         if self.metrics_compute_frequency is None:
             self.metrics_compute_frequency = num_iterations
 
-
         self.msts = self.fit_validator_train.msts
         self.model_arch_table = self.fit_validator_train.model_arch_table
         self.object_table = self.fit_validator_train.object_table
         self.metrics_iters = []
         self.object_map_col = 'object_map'
+        self.custom_mst_keys = None
         if self.object_table is not None:
             self.populate_object_map()
 
-        original_cuda_env = None
+        self.original_cuda_env = None
         if CUDA_VISIBLE_DEVICES_KEY in os.environ:
-            original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
+            self.original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
         self.dist_key_mapping, self.images_per_seg_train = \
             get_image_count_per_seg_for_minibatched_data_from_db(
@@ -197,36 +202,48 @@ class FitMultipleModel():
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        self.extra_dist_keys = []
+
+        num_msts = self.num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = GP_SEGMENT_ID_COLNAME
-        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
+        # Ordered list of SQL representations of each mst_key,
+        #  including NULLs.  This will be used to pass the mst keys
+        #  to the db as a SQL ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
 
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the number of segments we'll be training on -- these are the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
 
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
+        self.gp_segment_id_col = GP_SEGMENT_ID_COLNAME
+        self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
@@ -234,35 +251,54 @@ class FitMultipleModel():
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS {self.original_model_output_tbl};
+                                    CREATE TABLE {self.original_model_output_tbl} AS
+                                    SELECT
+                                        {self.mst_key_col}::INTEGER,
+                                        {self.model_weights_col}::BYTEA,
+                                        {self.model_arch_col}::JSON,
+                                        {self.dist_key_col}::INTEGER
+                                    FROM {self.model_output_tbl}
+                                    DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(final_output_table_create_query)
+        self.truncate_and_drop(self.model_output_tbl)
 
     def train_multiple_model(self):
-        total_msts = len(self.msts_for_schedule)
+        total_msts = len(self.all_mst_keys)
+        DEBUG.start_timing('train_multiple_model_extra')
+
         for iter in range(1, self.num_iterations+1):
-            for mst_idx in range(total_msts):
-                mst_row = [self.grand_schedule[dist_key][mst_idx]
-                           for dist_key in self.dist_keys]
-                self.create_mst_schedule_table(mst_row)
-                self.is_final_training_call = (iter == self.num_iterations and mst_idx == total_msts-1)
-                if mst_idx == 0:
+            for hop in range(total_msts):
+                self.is_final_training_call = (iter == self.num_iterations and hop == total_msts-1)
+                if hop == 0:
                     start_iteration = time.time()
-                self.run_training(mst_idx, mst_idx==0 and iter==1)
-                if mst_idx == (total_msts - 1):
+
+                self.run_training(hop, hop==0 and iter==1)
+                DEBUG.start_timing('train_multiple_model_extra')
+
+                if hop == (total_msts - 1):
                     end_iteration = time.time()
                     self.info_str = "\tTime for training in iteration " \
                                     "{0}: {1} sec\n".format(iter,
                                                             end_iteration -
                                                             start_iteration)
+                else:
+                    self.rotate_schedule_tbl()
+
             if should_compute_metrics_this_iter(iter,
                                                 self.metrics_compute_frequency,
                                                 self.num_iterations):
@@ -272,9 +308,12 @@ class FitMultipleModel():
                 if self.validation_table:
                     self.evaluate_model(iter, self.validation_table, False)
             plpy.info("\n"+self.info_str)
-        plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table};".format(self=self))
-
+        plpy.execute("DROP TABLE IF EXISTS {self.schedule_tbl}".format(self=self))
+        if self.use_caching:
+            plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table}".format(self=self))
+ 
     def evaluate_model(self, epoch, table, is_train):
+        DEBUG.start_timing('eval_model_total')
         if is_train:
             mst_metric_eval_time = self.train_mst_metric_eval_time
             mst_loss = self.train_mst_loss
@@ -289,7 +328,8 @@ class FitMultipleModel():
             images_per_seg = self.images_per_seg_valid
             self.info_str += "\n\tValidation set after iteration {0}:".format(epoch)
         for mst in self.msts:
-            model_arch, _ = get_model_arch_weights(self.model_arch_table, mst[self.model_id_col])
+            model_arch = get_model_arch(self.model_arch_table, mst[self.model_id_col])
+            DEBUG.start_timing('eval_compute_loss_and_metrics')
             _, metric, loss = compute_loss_and_metrics(
                 self.schema_madlib, table, "$madlib${0}$madlib$".format(
                     mst[self.compile_params_col]),
@@ -301,33 +341,28 @@ class FitMultipleModel():
                 images_per_seg,
                 [], [], True,
                 mst[self.object_map_col],
-                self.model_output_table,
+                self.model_output_tbl,
                 mst[self.mst_key_col])
+            DEBUG.print_timing('eval_compute_loss_and_metrics')
             mst_metric_eval_time[mst[self.mst_key_col]] \
                 .append(self.metrics_elapsed_time_offset + (time.time() - self.metrics_elapsed_start_time))
             mst_loss[mst[self.mst_key_col]].append(loss)
             mst_metric[mst[self.mst_key_col]].append(metric)
             self.info_str += "\n\tmst_key={0}: metric={1}, loss={2}".format(mst[self.mst_key_col], metric, loss)
-
-    def generate_schedule(self, msts):
-        """ Generate the schedule for models hopping to segments """
-        grand_schedule = {}
-        for index, dist_key in enumerate(self.dist_keys):
-            grand_schedule[dist_key] = rotate(msts, index)
-        return grand_schedule
+        DEBUG.print_timing('eval_model_total')
 
     def populate_object_map(self):
         builtin_losses = dir(losses)
         builtin_metrics = update_builtin_metrics(dir(metrics))
 
         # Track distinct custom functions in compile_params
-        custom_fn_names = []
+        custom_fn_names = set()
         # Track their corresponding mst_keys to pass along the custom function
         # definition read from the object table.
         # For compile_params calling builtin functions the object_map is set to
         # None.
-        custom_fn_mst_idx = []
-        for mst, mst_idx in zip(self.msts, range(len(self.msts))):
+        custom_msts = []
+        for mst in self.msts:
             compile_params = mst[self.compile_params_col]
             # We assume that the compile_param is validated as part
             # of the loading mst_table and thus not validating here
@@ -338,183 +373,299 @@ class FitMultipleModel():
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
-
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
-
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
-
-    def create_model_output_table(self):
-        output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                     {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
-                                    """.format(self=self)
-        plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
+
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
+
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
+
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
 
-    def create_model_output_table_warm_start(self):
+    def init_schedule_tbl(self):
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE {self.unlogged_table} TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        plpy_execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if self.rotate_schedule_tbl_plan is None:
+            rotate_schedule_tbl_query = """
+                CREATE {self.unlogged_table} TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl}
+                DISTRIBUTED BY ({self.prev_dist_key_col})
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
+
+    def load_warm_start_weights(self):
         """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the
+        compile & fit params from the model_selection_table, and
+        the dist_keys from the schedule table.
         """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        plpy_execute(load_warm_start_weights_query)
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from the
+            model_arch table.  Ignore models with
+            no xfer learning weights; these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        plpy_execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
-
-    def initialize_model_output_and_info(self):
+        output_table_create_query = """
+                                    CREATE {self.unlogged_table} TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.dist_key_col})
+                                    )
+                                    DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(output_table_create_query)
+
+        if self.warm_start:
+            self.load_warm_start_weights()
+        else:  # Note:  We only support xfer learning when warm_start=False
+            self.load_xfer_learning_weights()
+
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_keys FROM {self.model_output_tbl}
+        """.format(self=self))
+       
+        if res:
+            initialized_msts = set([ row['mst_keys'] for row in res ])
+        else:
+            initialized_msts = set()
+
+        # We've already bulk loaded all of the models with user-specified weights.
+        #  For the rest of the models, we need to generate the weights for each
+        #  by initializing them with keras and adding them one row at a time.
+        #
+        # TODO:  In the future, we should probably move the weight initialization
+        #  into the transition function on the segments.  Here, we would just
+        #  bulk load everything with a single query (or 2, for the warm start case),
+        #  and leave the weights column as NULL for any model whose weights need
+        #  to be randomly initialized.  Then in fit_transition, if prev_weights is
+        #  NULL, and there is nothing in GD, it should just skip the call to
+        #  set_weights(), and keras will automatically initialize them during
+        #  model.from_json(model_arch).
+        #
+        #  This would be a very easy change for fit_multiple(), but might require
+        #   some more work to support fit().  All of the segments there need to
+        #   start with the same weights, so we'd at least have to pass a random
+        #   seed to the transition function for keras to use.  Or generate a seed
+        #   on the segments in some deterministic way that's the same for all.
+        for index, mst in enumerate(self.msts_for_schedule):
+            if mst is None:
+                continue
+
+            if mst['mst_key'] in initialized_msts:
+                continue  # skip if we've already loaded this mst
+
+            num_dist_keys = len(self.dist_keys)
+
+            if index < num_dist_keys:
+                dist_key = self.dist_keys[index]
+            else:  # For models that won't be trained on first hop
+                dist_key = self.extra_dist_keys[index - num_dist_keys]
+
+            model_arch = get_model_arch(self.model_arch_table, mst[self.model_id_col])
+            serialized_weights = get_initial_weights(None, model_arch, None, False,
+                                                     self.accessible_gpus_for_seg)
+
+            output_table_add_row_query = """
+                INSERT INTO {self.model_output_tbl} (
+                    {self.mst_key_col},
+                    {self.model_weights_col},
+                    {self.model_arch_col},
+                    {self.compile_params_col},
+                    {self.fit_params_col},
+                    {self.object_map_col},
+                    {self.dist_key_col}
+                ) VALUES (
+                    $MADLIB${{{self.mst_key_col}}}$MADLIB$,
+                    $1,
+                    $2,
+                    $MADLIB${{{self.compile_params_col}}}$MADLIB$,
+                    $MADLIB${{{self.fit_params_col}}}$MADLIB$,
+                    NULL, -- Fill in custom object_map soon
+                    $3
+                )
+            """.format(self=self).format(**mst)
+
+            output_table_add_row_query_prepared = plpy.prepare(
+                output_table_add_row_query,
+                ["BYTEA", "JSON", "INTEGER"]
+            )
+
+            plpy.execute(output_table_add_row_query_prepared,
+                [ serialized_weights, model_arch, dist_key ]
+            )
+
+        if self.custom_mst_keys:
+            custom_keys = '({})'.format(
+                ','.join( map(str, self.custom_mst_keys) )
+            )
+
+            # Add object_map to any msts which use custom functions
+            if self.add_object_maps_plan is None:
+                self.add_object_maps_plan = plpy.prepare("""
+                    UPDATE {self.model_output_tbl}
+                        SET {self.object_map_col} = $1
+                            WHERE {self.mst_key_col} IN {custom_keys}
+                """.format(**locals()), ["BYTEA"])
+            plpy.execute(self.add_object_maps_plan, [self.custom_fn_object_map])
+
+    def init_model_info_tbl(self):
         info_table_create_query = """
-                                  CREATE TABLE {self.model_info_table}
-                                  ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                   {self.model_id_col} INTEGER,
-                                   {self.compile_params_col} TEXT,
-                                   {self.fit_params_col} TEXT,
-                                   model_type TEXT,
-                                   model_size DOUBLE PRECISION,
-                                   metrics_elapsed_time DOUBLE PRECISION[],
-                                   metrics_type TEXT[],
-                                   loss_type TEXT,
-                                   training_metrics_final DOUBLE PRECISION,
-                                   training_loss_final DOUBLE PRECISION,
-                                   training_metrics DOUBLE PRECISION[],
-                                   training_loss DOUBLE PRECISION[],
-                                   validation_metrics_final DOUBLE PRECISION,
-                                   validation_loss_final DOUBLE PRECISION,
-                                   validation_metrics DOUBLE PRECISION[],
-                                   validation_loss DOUBLE PRECISION[])
-                                       """.format(self=self)
+            DROP TABLE IF EXISTS {self.model_info_tbl};
+            CREATE TABLE {self.model_info_tbl} (
+                {self.mst_key_col} INTEGER PRIMARY KEY,
+                {self.model_id_col} INTEGER,
+                {self.compile_params_col} TEXT,
+                {self.fit_params_col} TEXT,
+                model_type TEXT,
+                model_size DOUBLE PRECISION,
+                metrics_elapsed_time DOUBLE PRECISION[],
+                metrics_type TEXT[],
+                loss_type TEXT,
+                training_metrics_final DOUBLE PRECISION,
+                training_loss_final DOUBLE PRECISION,
+                training_metrics DOUBLE PRECISION[],
+                training_loss DOUBLE PRECISION[],
+                validation_metrics_final DOUBLE PRECISION,
+                validation_loss_final DOUBLE PRECISION,
+                validation_metrics DOUBLE PRECISION[],
+                validation_loss DOUBLE PRECISION[]
+           ) """.format(self=self)
 
         plpy.execute(info_table_create_query)
-        for mst in self.msts:
-            model_arch, model_weights = get_model_arch_weights(self.model_arch_table,
-                                                               mst[self.model_id_col])
-
-
-            # If warm start is enabled, weights from transfer learning cannot be
-            # used, even if a particular model doesn't have warm start weights.
-            if self.warm_start:
-                model_weights = None
-                mst_filter = """
-                            WHERE {mst_col}={mst_key}
-                        """.format(
-                    mst_col=self.mst_key_col,
-                    mst_key=mst['mst_key']
-                )
 
-            else:
-                mst_filter = ''
-
-            serialized_weights = get_initial_weights(self.model_output_table,
-                                                     model_arch,
-                                                     model_weights,
-                                                     mst['mst_key'] in self.warm_start_msts,
-                                                     self.use_gpus,
-                                                     self.accessible_gpus_for_seg,
-                                                     mst_filter
-                                                     )
-            model_size = sys.getsizeof(serialized_weights) / 1024.0
+        info_table_insert_query = """
+            INSERT INTO {self.model_info_tbl} (
+                {self.mst_key_col},
+                {self.model_id_col},
+                {self.compile_params_col},
+                {self.fit_params_col},
+                model_type,
+                model_size
+            )
+            SELECT
+                m.{self.mst_key_col},
+                m.{self.model_id_col},
+                m.{self.compile_params_col},
+                m.{self.fit_params_col},
+                '{model_type}',
+                LENGTH(o.{self.model_weights_col})/1024.0
+            FROM {self.model_selection_table} m JOIN {self.model_output_tbl} o
+                USING ({self.mst_key_col})
+        """.format(self=self,
+                   model_type='madlib_keras')
+
+        plpy.execute(info_table_insert_query)
+
+        for mst in self.msts_for_schedule:
+            if mst is None:
+                continue
 
             metrics_list = get_metrics_from_compile_param(
                 mst[self.compile_params_col])
-            is_metrics_specified = True if metrics_list else False
             metrics_type = 'ARRAY{0}'.format(
-                metrics_list) if is_metrics_specified else 'NULL'
-
+                metrics_list) if metrics_list else 'NULL'
             loss_type = get_loss_from_compile_param(mst[self.compile_params_col])
             loss_type = loss_type if loss_type else 'NULL'
 
-            info_table_insert_query = """
-                            INSERT INTO {self.model_info_table}({self.mst_key_col},
-                                        {self.model_id_col}, {self.compile_params_col},
-                                        {self.fit_params_col}, model_type, model_size,
-                                        metrics_type, loss_type)
-                                VALUES ({mst_key_val}, {model_id},
-                                        $madlib${compile_params}$madlib$,
-                                        $madlib${fit_params}$madlib$, '{model_type}',
-                                        {model_size}, {metrics_type}, '{loss_type}')
-                        """.format(self=self,
-                                   mst_key_val=mst[self.mst_key_col],
-                                   model_id=mst[self.model_id_col],
-                                   compile_params=mst[self.compile_params_col],
-                                   fit_params=mst[self.fit_params_col],
-                                   model_type='madlib_keras',
-                                   model_size=model_size,
-                                   metrics_type=metrics_type,
-                                   loss_type=loss_type)
-            plpy.execute(info_table_insert_query)
-
-            if not mst['mst_key'] in self.warm_start_msts:
-                output_table_insert_query = """
-                                    INSERT INTO {self.model_output_table}(
-                                        {self.mst_key_col}, {self.model_weights_col},
-                                        {self.model_arch_col})
-                                    VALUES ({mst_key}, $1, $2)
-                                       """.format(self=self,
-                                                  mst_key=mst[self.mst_key_col])
-                output_table_insert_query_prepared = plpy.prepare(
-                    output_table_insert_query, ["bytea", "json"])
-                plpy.execute(output_table_insert_query_prepared, [
-                             serialized_weights, model_arch])
+            plpy.execute("""
+                UPDATE {self.model_info_tbl} SET
+                    metrics_type = {metrics_type},
+                    loss_type = '{loss_type}'
+                WHERE {self.mst_key_col} = {{{self.mst_key_col}}}
+            """.format(self=self,
+                       metrics_type=metrics_type,
+                       loss_type=loss_type
+              ).format(**mst))
+
+        DEBUG.print_timing('init_model_output_and_info')
 
     def create_model_summary_table(self):
         if self.warm_start:
@@ -548,8 +699,8 @@ class FitMultipleModel():
                 SELECT
                     $MAD${self.source_table}$MAD$::TEXT AS source_table,
                     {self.validation_table}::TEXT AS validation_table,
-                    $MAD${self.model_output_table}$MAD$::TEXT AS model,
-                    $MAD${self.model_info_table}$MAD$::TEXT AS model_info,
+                    $MAD${self.original_model_output_tbl}$MAD$::TEXT AS model,
+                    $MAD${self.model_info_tbl}$MAD$::TEXT AS model_info,
                     $MAD${dependent_varname}$MAD$::TEXT AS dependent_varname,
                     $MAD${independent_varname}$MAD$::TEXT AS independent_varname,
                     $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
@@ -592,7 +743,7 @@ class FitMultipleModel():
 
         if is_train:
             update_query = """
-                           UPDATE {self.model_info_table} SET
+                           UPDATE {self.model_info_tbl} SET
                            training_metrics_final = {metrics_final},
                            training_loss_final = {loss_final},
                            metrics_elapsed_time = {metrics_elapsed_time},
@@ -602,7 +753,7 @@ class FitMultipleModel():
                            """.format(**locals())
         else:
             update_query = """
-                           UPDATE {self.model_info_table} SET
+                           UPDATE {self.model_info_tbl} SET
                            validation_metrics_final = {metrics_final},
                            validation_loss_final = {loss_final},
                            metrics_elapsed_time = {metrics_elapsed_time},
@@ -617,8 +768,43 @@ class FitMultipleModel():
             self.update_info_table(mst, True)
             if self.validation_table:
                 self.update_info_table(mst, False)
-
-    def run_training(self, mst_idx, is_very_first_hop):
+   
+    def run_training(self, hop, is_very_first_hop):
+        """
+               This method is called once per hop from the main fit_multiple_model loop.
+            The hop param here identifies the hop number within an iteration, starting
+            over each iteration at hop 0.  It ranges from 0 to the greater of either
+            the number of model configs in the mst table or the number of segments with
+            data on them.  This ensures that each model config gets paired with each
+            data segment exactly once per iteration.
+
+               If there are more segments than model configs, then there will be some
+            NULL mst_key rows in the model_input & model_output tables.  If instead there
+            are more mst keys than segments, then the models not being trained this round
+            will have "extra" dist keys, meaning dist_key > max_dist_key where max_dist_key
+            is the largest dist key in the source table.  Each of these will be distributed
+            on some segment, but we don't care which.
+
+            There are 2 main tasks performed in run_training():
+                1.)  The actual hop - each of the rows in the model_output table from the
+                     previous round are permuted onto the next segment in a round-robin
+                     fashion... the result is saved as the model_input table for this round.
+                     The bulk of the data in each row is the model weights.  The schedule
+                     table is there to guide each of these models from their previous location
+                     to their new scheduled location, where they will train this round.
+
+                2.)  Calling fit_transition_multiple_model() - We join the model_input
+                     table with the data source table to train the models on the data local
+                     to their segment.  The most important concern here is making sure that
+                     the plan for this query does not redistribute any of the model weights.
+                     The dist keys are carefully chosen so that there should be no data
+                     movement--the only time the model weights move is during the actual
+                     hop.  Without caching, the models are trained one row at a time,
+                     conceptually similar to a UDA.  With caching enabled, all of the
+                     rows are combined in memory on the very first round.  So after that
+                     we replace the source table with an empty table (cached_source_table),
+                     containing only 1 row per segment, with dist keys but no actual data.
+        """
         # NOTE: In the DL module, we want to avoid CREATING TEMP tables
         # (creates a slice which stays until the session is disconnected)
         # or minimize writing queries that generate plans with Motions (creating
@@ -630,116 +816,170 @@ class FitMultipleModel():
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+
+            if self.hop_plan is None:
+                self.hop_plan = plpy_prepare("""
+                    CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                        SELECT o.{self.mst_key_col},
+                               o.{self.model_weights_col},
+                               o.{self.model_arch_col},
+                               o.{self.compile_params_col},
+                               o.{self.fit_params_col},
+                               o.{self.object_map_col},
+                               s.{self.dist_key_col}
+                        FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                            ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                        DISTRIBUTED BY ({self.dist_key_col})
+                    """.format(self=self)
+                )
+
+            plpy_execute(self.hop_plan)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            self.truncate_and_drop(self.model_output_tbl)
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+ 
         ind_shape_col = self.ind_shape_col
-        dep_var = mb_dep_var_col
-        indep_var = mb_indep_var_col
+        dep_shape_col = self.dep_shape_col
+        dep_var_col = mb_dep_var_col
+        indep_var_col = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE {self.unlogged_table} TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape_col = ind_shape_col = 'NULL'
+                dep_var_col = indep_var_col = 'NULL'
                 source_table = self.cached_source_table
-            if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
-
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
-                       use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
-                       source_table=source_table,
-                       where_clause=where_clause,
-                       self=self
-                       )
-        plpy.execute(uda_query)
-
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
-
-        self.truncate_and_drop_tables()
 
-    def truncate_and_drop_tables(self):
+            if is_very_first_hop or self.is_final_training_call:
+                num_msts = self.num_msts
+                num_segs = len(self.dist_keys)
+                if num_msts < num_segs:
+                    # Add some empty rows, so that cache gets
+                    #  populated or deleted on all segments, not
+                    #  just those with models on them currently.
+                    insert_empty_rows_query = """
+                        INSERT INTO {self.model_input_tbl} (__dist_key__)
+                            SELECT __dist_key__ FROM {self.schedule_tbl}
+                                WHERE {self.mst_key_col} IS NULL
+                    """.format(self=self)
+                    plpy_execute(insert_empty_rows_query)
+
+        DEBUG.start_timing("udf")
+        if self.udf_plan is None:
+            self.udf_plan = plpy_prepare("""
+                CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+                SELECT
+                    model_in.{self.mst_key_col},
+                    CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                    THEN
+                        model_in.{self.model_weights_col}
+                    ELSE
+                        {self.schema_madlib}.fit_transition_multiple_model(
+                            {dep_var_col},
+                            {indep_var_col},
+                            {dep_shape_col},
+                            {ind_shape_col},
+                            model_in.{self.model_arch_col}::TEXT,
+                            model_in.{self.compile_params_col}::TEXT,
+                            model_in.{self.fit_params_col}::TEXT,
+                            src.{self.dist_key_col},
+                            ARRAY{self.dist_key_mapping},
+                            src.{self.gp_segment_id_col},
+                            {self.segments_per_host},
+                            ARRAY{self.images_per_seg_train},
+                            ARRAY{self.accessible_gpus_for_seg},
+                            model_in.{self.model_weights_col}::BYTEA,
+                            $1::BOOLEAN, -- is_final_training_call
+                            {self.use_caching}::BOOLEAN,
+                            model_in.{self.object_map_col}::BYTEA
+                        )
+                    END::BYTEA AS {self.model_weights_col},
+                    model_in.{self.model_arch_col},
+                    model_in.{self.compile_params_col},
+                    model_in.{self.fit_params_col},
+                    model_in.{self.object_map_col},
+                    model_in.{self.dist_key_col}
+                FROM {self.model_input_tbl} model_in
+                    LEFT JOIN {source_table} src
+                    USING ({self.dist_key_col}) 
+                DISTRIBUTED BY({self.dist_key_col})
+                """.format(dep_var_col=dep_var_col,
+                           indep_var_col=indep_var_col,
+                           dep_shape_col=dep_shape_col,
+                           ind_shape_col=ind_shape_col,
+                           source_table=source_table,
+                           self=self
+                           ),
+                [ 'BOOLEAN' ]
+            )
+
+        try:
+            plpy_execute(self.udf_plan, [ self.is_final_training_call ] )
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'UDF_Detail' in msg:
+                raise e
+            e.message, detail = msg.split('UDF_Detail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
+        DEBUG.print_timing("udf")
+
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights IS NULL".format(self=self))
+
+        self.truncate_and_drop(self.model_input_tbl)
+
+        if self.use_caching and is_very_first_hop:
+            # Throw away plan for source_table, force generation of a new one
+            #  next time for cached_source_table
+            self.udf_plan = None
+
+        DEBUG.print_timing("run_training")
+
+    def truncate_and_drop(self, table):
         """
-        Context: UPDATE statements in postgres are not in-place replacements but
-        the row to be updated is marked for deletion(note that the disk space for
-        this row doesn't get released until vaccuum is called) and a new row in
-        inserted.
-
-        This function will clear out the disk space used by the model_output_table
-        and also drop all the other intermediate tables.
-        If available, set the `` guc so that the truncate command can release the
-        disk space. The disk space will be released immediately and hence the
-        model_output table won't grow in size with each UPDATE statement.
+        This function truncates and drops one of the intermediate tables used
+        during an iteration (model_input_tbl, model_output_tbl, schedule_tbl).
+        If available, set the `dev_opt_unsafe_truncate_in_subtransaction` guc 
+        so that the truncate command can release the disk space. The disk space
+        will be released immediately and hence the model_output table won't grow
+        in size with each hop.
 
         Without this guc, the disk space won't be released and each
-        call to the UPDATE statement will keep adding to the disk space. The disk
-        space will only be released when the query is completed.
+        call to TRUNCATE or DROP will keep adding to the disk space. The
+        disk space will only be released when the query is completed.
 
         The guc can cause data loss if not used properly. Since truncate will
         actually clear the disk space immediately, there is no way to recover to
@@ -747,31 +987,10 @@ class FitMultipleModel():
         be set for intermediate tables and never for tables created outside the
         scope of the fit_multiple udf.
 
-        Workflow
-        1. Create temp table from model table (including the indexes)
-        2. truncate the model table to release disk space
-        3. rename temp table to model table so that it can be reused for the next
-        hop
-        :return:
         """
 
         with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
-            temp_model_table = unique_string('updated_model')
-            unlogged_table = self.unlogged_table if not self.is_final_training_call else ''
-            plpy.execute("""
-            CREATE {unlogged_table} TABLE {temp_model_table} ( LIKE {self.model_output_table}
-            INCLUDING indexes);""".format(temp_model_table=temp_model_table,
-                                          unlogged_table=unlogged_table,
-                                          self=self))
-            plpy.execute("""
-            INSERT INTO {temp_model_table} SELECT * FROM {self.model_output_table};
-            TRUNCATE TABLE {self.model_output_table};
-            DROP TABLE {self.model_output_table};
-            """.format(temp_model_table=temp_model_table, self=self))
-            rename_table(self.schema_madlib, temp_model_table,
-                         self.model_output_table)
             plpy.execute("""
-            TRUNCATE TABLE {self.mst_weights_tbl}, {self.mst_current_schedule_tbl},
-            {self.weights_to_update_tbl};
-            DROP TABLE IF EXISTS {self.mst_weights_tbl}, {self.mst_current_schedule_tbl},
-            {self.weights_to_update_tbl};""".format(self=self))
+                TRUNCATE TABLE {table};
+                DROP TABLE {table}
+            """.format(table=table))
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 8b6aef2..b0ac70b 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1420,23 +1420,25 @@ File madlib_keras_fit_multiple_model.sql_in documents training, evaluate and pre
 */
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
-    source_table            VARCHAR,
-    model_output_table      VARCHAR,
-    model_selection_table   VARCHAR,
-    num_iterations          INTEGER,
-    use_gpus                BOOLEAN,
-    validation_table        VARCHAR,
-    metrics_compute_frequency  INTEGER,
-    warm_start              BOOLEAN,
-    name                    VARCHAR,
-    description             VARCHAR,
-    use_caching             BOOLEAN DEFAULT FALSE
+    source_table                VARCHAR,
+    model_output_table          VARCHAR,
+    model_selection_table       VARCHAR,
+    num_iterations              INTEGER,
+    use_gpus                    BOOLEAN,
+    validation_table            VARCHAR,
+    metrics_compute_frequency   INTEGER,
+    warm_start                  BOOLEAN,
+    name                        VARCHAR,
+    description                 VARCHAR,
+    use_caching                 BOOLEAN DEFAULT FALSE
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
     from utilities.control import SetGUC
     with AOControl(False):
         with SetGUC("plan_cache_mode", "force_generic_plan"):
-            fit_obj = madlib_keras_fit_multiple_model.FitMultipleModel(**globals())
+            with MinWarning("warning"):
+                fit_obj = madlib_keras_fit_multiple_model.FitMultipleModel(**globals())
+                fit_obj.fit_multiple_model()
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
@@ -1506,7 +1508,6 @@ $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
-    state                      BYTEA,
     dependent_var              BYTEA,
     independent_var            BYTEA,
     dependent_var_shape        INTEGER[],
@@ -1520,57 +1521,26 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
     segments_per_host          INTEGER,
     images_per_seg             INTEGER[],
     accessible_gpus_for_seg    INTEGER[],
-    prev_serialized_weights    BYTEA,
+    serialized_weights         BYTEA,
     is_final_training_call     BOOLEAN,
     use_caching                BOOLEAN,
     custom_function_map        BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    if use_caching:
-        return madlib_keras.fit_multiple_transition_caching(**globals())
-    else:
-        return madlib_keras.fit_transition(is_multiple_model = True, **globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        if use_caching:
+            return madlib_keras.fit_multiple_transition_caching(**globals())
+        else:
+            return madlib_keras.fit_transition(state=None, prev_serialized_weights=serialized_weights,
+                                               is_multiple_model=True, **globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + '\nTransAggDetail:\n' + detail
+        e.args = (message,)
+        raise e
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
-
-DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step_multiple_model(
-    BYTEA,
-    BYTEA,
-    INTEGER[],
-    INTEGER[],
-    TEXT,
-    TEXT,
-    TEXT,
-    INTEGER,
-    INTEGER[],
-    INTEGER,
-    INTEGER,
-    INTEGER[],
-    BOOLEAN,
-    INTEGER[],
-    BYTEA,
-    BOOLEAN,
-    BOOLEAN,
-    BYTEA);
-CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
-    /* dependent_var */              BYTEA,
-    /* independent_var */            BYTEA,
-    /* dependent_var_shape */        INTEGER[],
-    /* independent_var_shape */      INTEGER[],
-    /* model_architecture */         TEXT,
-    /* compile_params */             TEXT,
-    /* fit_params */                 TEXT,
-    /* dist_key */                   INTEGER,
-    /* dist_key_mapping */           INTEGER[],
-    /* current_seg_id */             INTEGER,
-    /* segments_per_host */          INTEGER,
-    /* images_per_seg */             INTEGER[],
-    /* accessible_gpus_for_seg */    INTEGER[],
-    /* prev_serialized_weights */    BYTEA,
-    /* is_final_training_call */     BOOLEAN,
-    /* use_caching */                BOOLEAN,
-    /* custom_function_obj_map */    BYTEA
-)(
-    STYPE=BYTEA,
-    SFUNC=MADLIB_SCHEMA.fit_transition_multiple_model
-);
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index cf030e1..15f2493 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -139,7 +139,7 @@ def get_image_count_per_seg_for_minibatched_data_from_db(table_name):
                 FROM {1}
             """.format(shape_col, table_name))
         images_per_seg = [sum(r['shape'][0] for r in res)]
-        seg_ids = [0]
+        dist_keys = [0]
     else:
         # The number of images in the buffer is the first dimension in the shape.
         # Using __dist_key__ instead of gp_segment_id: Since gp_segment_id is
@@ -159,12 +159,12 @@ def get_image_count_per_seg_for_minibatched_data_from_db(table_name):
                 FROM {2}
                 GROUP BY {0}
             """.format(DISTRIBUTION_KEY_COLNAME, shape_col, table_name))
-        seg_ids = [int(each_segment[DISTRIBUTION_KEY_COLNAME])
+        dist_keys = [int(each_segment[DISTRIBUTION_KEY_COLNAME])
                    for each_segment in images_per_seg]
         images_per_seg = [int(each_segment["images_per_seg"])
                           for each_segment in images_per_seg]
 
-    return seg_ids, images_per_seg
+    return dist_keys, images_per_seg
 
 def get_image_count_per_seg_for_non_minibatched_data_from_db(table_name):
     """
@@ -235,6 +235,17 @@ def query_dist_keys(source_table, dist_key_col):
     res = [x[dist_key_col] for x in res]
     return res
 
+def query_weights(model_output_table, model_weights_col, mst_key_col, mst_key):
+    mlp_weights_query = """
+                        SELECT {model_weights_col}, {mst_key_col}
+                        FROM {model_output_table}
+                        WHERE {mst_key_col} = {mst_key}
+                        """.format(**locals())
+    res = plpy.execute(mlp_weights_query)
+    if not res:
+        plpy.error("query_weights:  No weights in model output table for mst={}".format(mst_key))
+    return res[0][model_weights_col]
+
 def create_summary_view(module_name, model_table, mst_key):
     tmp_view_summary = unique_string('tmp_view_summary')
     model_summary_table = add_postfix(model_table, "_summary")
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 62d3cf7..62b349e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -335,7 +335,6 @@ def internal_keras_predict(independent_var, model_architecture, model_weights,
         clear_keras_session()
         plpy.error(ex)
 
-
 def predict_help(schema_madlib, message, **kwargs):
     """
     Help function for keras predict
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
index 6fa210c..7d96887 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
@@ -17,6 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import plpy
 from utilities.utilities import _assert
 
 # TODO
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index c60a19b..d7b2d41 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -66,7 +66,6 @@ def reset_cuda_env(value):
             del os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
 def get_device_name_and_set_cuda_env(gpu_count, seg):
-
     if gpu_count > 0:
         device_name = '/gpu:0'
         if is_platform_pg():
@@ -378,7 +377,7 @@ def query_custom_functions_map(object_table, custom_fn_names):
     Args:
         @param: object_table    Name of the object table
         @param: custom_fn_names List of custom function read from compile_param
-                                if custom function exisst in compile_params,
+                                if custom functions exist in compile_params,
                                     expected list length >= 1
                                 else,
                                     an empty list is passed in
@@ -390,16 +389,17 @@ def query_custom_functions_map(object_table, custom_fn_names):
                                 {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
 
     """
+    # Dictionary map of name:object
+    # {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
+    custom_fn_map = dict()
+
     if len(custom_fn_names) < 1:
-        return None
+        return custom_fn_map
 
     fn_set = set(custom_fn_names)
-    unique_fn_list = (list(fn_set))
+    unique_fn_list = list(fn_set)
 
-    custom_obj_col_name = '{0}'.format(CustomFunctionSchema.FN_OBJ)
-    # Dictionary map of name:object
-    # {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
-    custom_fn_map = defaultdict(list)
+    custom_obj_col_name = CustomFunctionSchema.FN_OBJ
     # Query the custom function if not yet loaded from table
     res = plpy.execute("""
                         SELECT {custom_fn_col_name}, {custom_obj_col_name} FROM {object_table}
diff --git a/src/ports/postgres/modules/deep_learning/model_arch_info.py_in b/src/ports/postgres/modules/deep_learning/model_arch_info.py_in
index 8f7418b..298f63a 100644
--- a/src/ports/postgres/modules/deep_learning/model_arch_info.py_in
+++ b/src/ports/postgres/modules/deep_learning/model_arch_info.py_in
@@ -85,8 +85,31 @@ def get_model_arch_layers_str(model_arch):
             layers += "{1}\n".format(class_name)
     return layers
 
-def get_model_arch_weights(model_arch_table, model_id):
+def get_model_arch(model_arch_table, model_id):
+    """
+    Returns only the model arch JSON (no weights).  For fit_multiple, we
+    don't want to keep sending weights back and forth between the main
+    host and the segment hosts.  Weights can be up to 1GB in size,
+    whereas the model arch in JSON is usually very small.
+    """
+    s = ModelArchSchema
+    model_arch_query = """
+        SELECT {s.MODEL_ARCH} FROM {model_arch_table}
+            WHERE {s.MODEL_ID} = {model_id}
+    """.format(**locals())
 
+    model_arch_result = plpy.execute(model_arch_query)
+    if not model_arch_result or len(model_arch_result) != 1:
+        plpy.error("no model arch found in table {0} with id {1}".format(
+            model_arch_table, model_id))
+
+    model_arch = model_arch_result[0][ModelArchSchema.MODEL_ARCH]
+    return model_arch
+
+def get_model_arch_weights(model_arch_table, model_id):
+    """
+    For fit, we need both the model arch & model weights
+    """
     #assume validation is already called
     model_arch_query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
         ModelArchSchema.MODEL_ARCH, ModelArchSchema.MODEL_WEIGHTS,
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index e2a8622..af3bdc0 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -110,44 +110,43 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_state.tostring(),  **kwargs)
+
+        image_count = kwargs['GD']['agg_image_count']
+        self.assertEqual(ending_image_count, image_count)
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
 
-    def _test_fit_transition_multiple_model_no_cache_first_buffer_pass(self,
-                                                                      **kwargs):
+    def _test_fit_transition_multiple_model_no_cache_first_buffer_pass(self, **kwargs):
         ending_image_count = len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
-
         new_state = self.subject.fit_transition(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(),
+            self.accessible_gpus_for_seg, self.serialized_weights,
              True, **kwargs)
 
-        image_count = new_state
+        self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
+        image_count = kwargs['GD']['agg_image_count']
         self.assertEqual(ending_image_count, image_count)
 
     def test_fit_transition_multiple_model_cache_first_buffer_pass(self):
         ending_image_count = len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
-
         k = {'GD': {}}
         new_state = self.subject.fit_multiple_transition_caching(
-            None, self.dependent_var, self.independent_var,
+            self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+            self.accessible_gpus_for_seg, self.serialized_weights, True, **k)
 
-        image_count = new_state
+        self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
+        image_count = k['GD']['agg_image_count']
         self.assertEqual(ending_image_count, image_count)
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue('cache_set' not in k['GD'])
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
@@ -162,7 +161,7 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **kwargs)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, **kwargs)
 
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
@@ -172,51 +171,56 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        state = starting_image_count
+        kwargs['GD']['agg_image_count'] = starting_image_count
+
         new_state = self.subject.fit_transition(
-            state, self.dependent_var, self.independent_var,
+            None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True,
             **kwargs)
 
-        image_count = new_state
+        self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
+        image_count = kwargs['GD']['agg_image_count']
         self.assertEqual(ending_image_count, image_count)
 
     def test_fit_transition_multiple_model_cache_middle_buffer_pass(self):
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
         x_train = list()
         y_train = list()
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
         y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
 
-        k = {'GD': {'x_train': x_train, 'y_train': y_train}}
+        k = {'GD': {'x_train': x_train, 'y_train': y_train,
+                    'agg_image_count' : starting_image_count
+                    }
+            }
 
-        state = starting_image_count
         new_state = self.subject.fit_multiple_transition_caching(
-            state, self.dependent_var, self.independent_var,
+            self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
+
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+            self.accessible_gpus_for_seg, self.serialized_weights, True, **k)
 
-        image_count = new_state
+        self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
+        image_count = k['GD']['agg_image_count']
         self.assertEqual(ending_image_count, image_count)
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue('cache_set' not in k['GD'])
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
     def _test_fit_transition_last_buffer_pass(self, **kwargs):
-
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
+        kwargs['GD']['agg_image_count'] = starting_image_count
+
         state = starting_image_count
         previous_state = np.array(self.model_weights, dtype=np.float32)
         new_state = self.subject.fit_transition(
@@ -226,6 +230,7 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_state.tostring(),
             **kwargs)
+
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         # We need to assert that the weights should be multiplied by final image count.
@@ -287,7 +292,6 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-
         state = [self.loss * starting_image_count,
                  self.accuracy * starting_image_count, starting_image_count]
 
@@ -310,9 +314,8 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
                                                                      **kwargs):
         starting_image_count = 2*len(self.dependent_var_int)
 
-        state = starting_image_count
         new_state = self.subject.fit_transition(
-            state , self.dependent_var, self.independent_var,
+            None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
@@ -320,17 +323,15 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             True, **kwargs)
 
         state = np.fromstring(new_state, dtype=np.float32)
-        weights = np.rint(state[0:]).astype(np.int)
 
         ## image count should not be added to the final state of
         # fit multiple
-        self.assertEqual(len(self.model_weights), len(weights))
+        self.assertEqual(len(self.model_weights), len(state))
 
     def test_fit_transition_multiple_model_cache_last_buffer_pass(self):
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
         x_train = list()
         y_train = list()
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
@@ -342,32 +343,34 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
 
         state = starting_image_count
         graph1 = self.subject.tf.get_default_graph()
+
+        k['GD']['agg_image_count'] = starting_image_count
+
         new_state = self.subject.fit_multiple_transition_caching(
-            state, self.dependent_var, self.independent_var,
+            self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+            self.accessible_gpus_for_seg, self.serialized_weights, False, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
         state = np.fromstring(new_state, dtype=np.float32)
-        weights = np.rint(state[0:]).astype(np.int)
 
         ## image count should not be added to the final state of
         # fit multiple
-        self.assertEqual(len(self.model_weights), len(weights))
+        self.assertEqual(len(self.model_weights), len(state))
 
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue(k['GD']['cache_set'])
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
+        # TODO:  test is_final_training_call = True
+
     def test_fit_transition_multiple_model_cache_filled_pass(self):
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
         x_train = list()
         y_train = list()
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
@@ -380,19 +383,18 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
         self.subject.compile_and_set_weights(self.model, self.compile_params,
                                                      '/cpu:0', self.serialized_weights)
         s1 = self.subject.K.get_session()
-        k = {'GD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True,
+        k = {'GD': {'x_train': x_train, 'y_train': y_train,
                     'sess': s1, 'segment_model': self.model}}
         graph1 = self.subject.tf.get_default_graph()
         new_state = self.subject.fit_multiple_transition_caching(
-            None, self.dependent_var, self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
+            None, None,
+            None, None,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+            self.accessible_gpus_for_seg, self.serialized_weights, False, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
-        state = np.fromstring(new_state, dtype=np.float32)
-        weights = np.rint(state[0:]).astype(np.int)
+        weights = np.fromstring(new_state, dtype=np.float32)
 
         ## image count should not be added to the final state of
         # fit multiple
@@ -400,7 +402,6 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
 
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue(k['GD']['cache_set'])
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
@@ -408,7 +409,6 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
         x_train = list()
         y_train = list()
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
@@ -418,19 +418,18 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
         y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
 
-        k = {'GD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+        k = {'GD': {'x_train': x_train, 'y_train': y_train }}
         graph1 = self.subject.tf.get_default_graph()
         new_state = self.subject.fit_multiple_transition_caching(
-            None, self.dependent_var, self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
+            None, None,
+            None, None,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+            self.accessible_gpus_for_seg, self.serialized_weights, True, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
 
-        state = np.fromstring(new_state, dtype=np.float32)
-        weights = np.rint(state[0:]).astype(np.int)
+        weights = np.fromstring(new_state, dtype=np.float32)
 
         ## image count should not be added to the final state of
         # fit multiple
@@ -438,7 +437,6 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
 
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue('cache_set' not in k['GD'])
         self.assertTrue('x_train' not in k['GD'])
         self.assertTrue('y_train' not in k['GD'])
 
@@ -627,6 +625,14 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
         self.assertTrue(iter_sess._closed)
         return iter_sess
 
+    def _init_GD(self, gd, starting_image_count=0):  # fills the caller's gd dict in place
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                             '/cpu:0', self.serialized_weights)
+        gd.update({'segment_model': self.model,  # update(): rebinding gd would be lost on return
+                    'sess': Mock(),
+                    'agg_image_count' : starting_image_count  # was an undefined name (NameError)
+        })
+
     def _assert_keras_session_same_as_gd_session(self, gd):
         sess = self.subject.K.get_session()
         self.assertEquals(sess, gd['sess'])
@@ -640,27 +646,6 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
 
     ################################################################
 
-    def test_fit_transition_first_tuple_none_ind_var_dep_var(self):
-        k = {}
-        self.assertEqual('dummy_state',
-                         self.subject.fit_transition('dummy_state', [0], None,
-                                                     'noshape', 'noshape',
-                                                     'dummy_model_json', "foo", "bar",
-                                                     1, [0,1,2], 0, 4, [3,3,3], False,
-                                                     [0], 'dummy_prev_state', **k))
-        self.assertEqual('dummy_state',
-                         self.subject.fit_transition('dummy_state', None, [[0.5]],
-                                                     'noshape', 'noshape',
-                                                     'dummy_model_json', "foo", "bar",
-                                                     1, [0,1,2], 0, 4, [3,3,3], False,
-                                                     [0], 'dummy_prev_state', **k))
-        self.assertEqual('dummy_state',
-                         self.subject.fit_transition('dummy_state', None, None,
-                                                     'noshape', 'noshape',
-                                                     'dummy_model_json', "foo", "bar",
-                                                     1, [0,1,2], 0, 4, [3,3,3], False,
-                                                     [0], 'dummy_prev_state', **k))
-
     def test_fit_merge(self):
         image_count = self.total_images_per_seg[0]
         state1 = [image_count]
@@ -750,7 +735,6 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
         res = self.subject.should_compute_metrics_this_iter(2, 1, 5)
         self.assertEqual(True, res)
 
-
 class InternalKerasPredictTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
@@ -1016,7 +1000,7 @@ class MadlibKerasWrapperTestCase(unittest.TestCase):
         target_dict = {'batch_size':2, 'epochs':1, 'verbose':0}
         literal_eval_fit_params = ['batch_size','epochs','verbose','shuffle',
                            'class_weight','initial_epoch','steps_per_epoch']
-        accepted_fit_params = literal_eval_fit_params + ['shuffle']
+        accepted_fit_params = literal_eval_fit_params
         result_params = self.subject.validate_and_literal_eval_keys(
                             test_dict,
                             literal_eval_fit_params,
@@ -1024,10 +1008,6 @@ class MadlibKerasWrapperTestCase(unittest.TestCase):
         self.assertDictEqual(result_params, target_dict)
 
     def test_parse_and_validate_fit_params(self):
-        result = {'batch_size':2, 'epochs':1, 'verbose':0}
-        self.assertDictEqual(result, self.subject.parse_and_validate_fit_params('batch_size=2, epochs=1, verbose=0'))
-
-    def test_parse_and_validate_fit_params(self):
         test_str = "batch_size=2, epochs=1, verbose=0"
         fit_dict = {'batch_size':2, 'epochs':1, 'verbose':0}
         result_params = self.subject.parse_and_validate_fit_params(test_str)