You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ok...@apache.org on 2019/10/14 21:52:23 UTC

[madlib] branch master updated: DL: Add training for multiple models

This is an automated email from the ASF dual-hosted git repository.

okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 610cf6d  DL: Add training for multiple models
610cf6d is described below

commit 610cf6de306499e170d8e03dab85e1ee5559ccaa
Author: Yuhao Zhang <yz...@pivotal.io>
AuthorDate: Mon Oct 14 17:49:27 2019 -0400

    DL: Add training for multiple models
    
    This commit adds a new function that trains multiple models in parallel
    using the model hopping method. It is supported on Greenplum DB only.
    
    The model hopping method involves the following steps:
    
    - Train models in parallel for a single epoch using the data local
    to each segment.
    - Move the models to the next segment in round-robin fashion.
    
    This method ensures that all of the models visit the entire dataset,
    which eliminates the need to average the model at the end.
    
    This commit also fixes the following issues:
    - In the regular fit function, an excessive number of threads was being
    created and left behind by Keras sessions. This is fixed by reusing
    the same session and the same computational graph throughout the
    process.
    - While deserializing weights, if the model shape expected fewer elements
    than were present in the weights, the excess weights would get dropped.
    This is fixed by adding an explicit check that the number of elements
    in the model weights matches the model.
    
    Closes #443
    
    Co-authored-by: Ekta Khanna <ek...@pivotal.io>
    Co-authored-by: Nandish Jayaram <nj...@apache.org>
    Co-authored-by: Nikhil Kak <nk...@pivotal.io>
    Co-authored-by: Orhan Kislal <ok...@apache.org>
---
 .../modules/deep_learning/madlib_keras.py_in       | 189 ++++++---
 .../modules/deep_learning/madlib_keras.sql_in      |  18 +-
 .../madlib_keras_fit_multiple_model.py_in          | 454 +++++++++++++++++++++
 .../madlib_keras_fit_multiple_model.sql_in         | 112 +++++
 .../deep_learning/madlib_keras_helper.py_in        |  39 +-
 .../madlib_keras_model_selection.py_in             |  73 ++--
 .../deep_learning/madlib_keras_predict.py_in       |   3 +-
 .../deep_learning/madlib_keras_serializer.py_in    |  13 +-
 .../deep_learning/madlib_keras_validator.py_in     |  70 +++-
 .../deep_learning/madlib_keras_wrapper.py_in       |  47 ++-
 .../test/madlib_keras_model_selection.sql_in       | 216 ++++++++--
 .../deep_learning/test/madlib_keras_predict.sql_in |  15 +-
 .../test/unit_tests/test_madlib_keras.py_in        | 404 +++++++++++++++---
 .../postgres/modules/utilities/utilities.py_in     |  14 +
 14 files changed, 1442 insertions(+), 225 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index b775af3..96fd3fe 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -40,6 +40,38 @@ from utilities.utilities import madlib_version
 from utilities.validate_args import get_expr_type
 from utilities.validate_args import quote_ident
 from utilities.control import MinWarning
+import tensorflow as tf
+
+class SD_STORE:
+    SESS = 'sess'
+    SEGMENT_MODEL = 'segment_model'
+
+    @staticmethod
+    def init_SD(SD, sess, segment_model):
+        SD[SD_STORE.SESS] = sess
+        SD[SD_STORE.SEGMENT_MODEL] = segment_model
+
+    @staticmethod
+    def clear_SD(SD):
+        del SD[SD_STORE.SEGMENT_MODEL]
+        del SD[SD_STORE.SESS]
+
+def get_init_model_and_sess(SD, device_name, gpus_per_host, segments_per_host,
+                               model_architecture, compile_params):
+    # If a live session is present, re-use it. Otherwise, recreate it.
+    if SD_STORE.SESS in SD :
+        if SD_STORE.SEGMENT_MODEL not in SD:
+            plpy.error("Session and model should exist in SD after the first row"
+                       "of the first iteration")
+        sess = SD[SD_STORE.SESS]
+        segment_model = SD[SD_STORE.SEGMENT_MODEL]
+        K.set_session(sess)
+    else:
+        sess = get_keras_session(device_name, gpus_per_host, segments_per_host)
+        K.set_session(sess)
+        segment_model = init_model(model_architecture, compile_params)
+        SD_STORE.init_SD(SD, sess, segment_model)
+    return segment_model, sess
 
 @MinWarning("warning")
 def fit(schema_madlib, source_table, model, model_arch_table,
@@ -117,10 +149,11 @@ def fit(schema_madlib, source_table, model, model_arch_table,
             ARRAY{images_per_seg_train},
             {gpus_per_host},
             {segments_per_host},
-            $1
+            $1,
+            $2
         ) AS iteration_result
         FROM {source_table}
-        """.format(**locals()), ["bytea"])
+        """.format(**locals()), ["bytea", "boolean"])
 
     # Define the state for the model and loss/metric storage lists
     training_loss, training_metrics, metrics_elapsed_time = [], [], []
@@ -132,8 +165,9 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     # Run distributed training for specified number of iterations
     for i in range(1, num_iterations+1):
         start_iteration = time.time()
+        is_final_iteration = (i == num_iterations)
         iteration_result = plpy.execute(run_training_iteration,
-                                        [serialized_weights])[0]['iteration_result']
+                                        [serialized_weights, is_final_iteration])[0]['iteration_result']
         end_iteration = time.time()
         info_str = "\tTime for training in iteration {0}: {1} sec".format(i,
             end_iteration - start_iteration)
@@ -146,7 +180,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
             compute_out = compute_loss_and_metrics(
                 schema_madlib, source_table, compile_params_to_pass, model_arch,
                 serialized_weights, gpus_per_host, segments_per_host, seg_ids_train,
-                images_per_seg_train, training_metrics, training_loss, i)
+                images_per_seg_train, training_metrics, training_loss, i, is_final_iteration)
             metrics_iters.append(i)
             compute_time, compute_metrics, compute_loss = compute_out
 
@@ -163,7 +197,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
                     schema_madlib, validation_table, compile_params_to_pass,
                     model_arch, serialized_weights, gpus_per_host, segments_per_host,
                     seg_ids_val, images_per_seg_val, validation_metrics,
-                    validation_loss, i)
+                    validation_loss, i, is_final_iteration)
                 val_compute_time, val_compute_metrics, val_compute_loss = val_compute_out
 
                 info_str += "\n\tTime for evaluating validation dataset in "\
@@ -344,7 +378,7 @@ def get_metrics_sql_string(metrics_list, is_metrics_specified):
 def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
                              serialized_weights, gpus_per_host, segments_per_host,
                              seg_ids, images_per_seg_val, metrics_list, loss_list,
-                             curr_iter):
+                             curr_iter, is_final_iteration):
     """
     Compute the loss and metric using a given model (serialized_weights) on the
     given dataset (table.)
@@ -358,7 +392,8 @@ def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
                                                    gpus_per_host,
                                                    segments_per_host,
                                                    seg_ids,
-                                                   images_per_seg_val)
+                                                   images_per_seg_val,
+                                                   is_final_iteration)
     end_val = time.time()
 
     if len(evaluate_result) not in [1, 2]:
@@ -387,63 +422,86 @@ def should_compute_metrics_this_iter(curr_iter, metrics_compute_frequency,
     return (curr_iter)%metrics_compute_frequency == 0 or \
            curr_iter == num_iterations
 
+def init_model(model_architecture, compile_params):
+    """
+        Should only be called at the first row of first iteration.
+    """
+    segment_model = model_from_json(model_architecture)
+    compile_model(segment_model, compile_params)
+    return segment_model
+
+def update_model(segment_model, prev_serialized_weights):
+    """
+        Happens at first row of each iteration.
+    """
+    model_shapes = get_model_shapes(segment_model)
+    model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
+        prev_serialized_weights, model_shapes)
+    segment_model.set_weights(model_weights)
+
 def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                    independent_var_shape, model_architecture,
                    compile_params, fit_params, current_seg_id, seg_ids,
                    images_per_seg, gpus_per_host, segments_per_host,
-                   prev_serialized_weights, **kwargs):
+                   prev_serialized_weights, is_final_iteration=True,
+                   is_multiple_model=False, **kwargs):
+    """
+    This transition function is common for madlib_keras_fit() and
+    madlib_keras_fit_multiple_model(). The important difference between
+    these two calls is the way this function handles the input param
+    prev_serialized_weights and clearing keras session.
+    For madlib_keras_fit_multiple_model,
+        a. prev_serialized_weights is always passed in as the state
+        (image count, serialized weights), since it is fetched in the
+        table for each hop of the model between segments.
+        b. keras session is cleared at the end of each iteration, i.e,
+        last row of each iteration.
+    For madlib_keras_fit,
+        a. prev_serialized_weights is passed in as serialized weights
+        b. keras session is cleared at the end of the final iteration,
+        i.e, last row of last iteration.
+    """
     if not independent_var or not dependent_var:
         return state
-
-    start_transition = time.time()
     SD = kwargs['SD']
     device_name = get_device_name_and_set_cuda_env(gpus_per_host,
                                                    current_seg_id)
-    # Set up system if this is the first buffer on segment'
+    if is_multiple_model:
+        prev_serialized_weights = madlib_keras_serializer.\
+            get_serialized_1d_weights_from_state(prev_serialized_weights)
+    segment_model, sess = get_init_model_and_sess(SD, device_name,
+                                                  gpus_per_host, segments_per_host,
+                                                  model_architecture, compile_params)
+    agg_image_count = madlib_keras_serializer.get_image_count_from_state(state)
     if not state:
-        set_keras_session(device_name, gpus_per_host, segments_per_host)
-        segment_model = model_from_json(model_architecture)
-        compile_and_set_weights(segment_model, compile_params, device_name,
-                                prev_serialized_weights)
-
-        SD['segment_model'] = segment_model
-        agg_image_count = 0
-    else:
-        segment_model = SD['segment_model']
-        agg_image_count = madlib_keras_serializer.get_image_count_from_state(state)
+        set_model_weights(segment_model, prev_serialized_weights)
 
     # Prepare the data
     x_train = np_array_float32(independent_var, independent_var_shape)
     y_train = np_array_int16(dependent_var, dependent_var_shape)
 
     # Fit segment model on data
-    start_fit = time.time()
-    with K.tf.device(device_name):
-        #TODO consider not doing this every time
-        fit_params = parse_and_validate_fit_params(fit_params)
-        history = segment_model.fit(x_train, y_train, **fit_params)
-    end_fit = time.time()
-
+    #TODO consider not doing this every time
+    fit_params = parse_and_validate_fit_params(fit_params)
+    history = segment_model.fit(x_train, y_train, **fit_params)
     image_count = len(x_train)
+
     # Aggregating number of images, loss and accuracy
     agg_image_count += image_count
-
-    with K.tf.device(device_name):
-        updated_weights = segment_model.get_weights()
-
+    updated_weights = segment_model.get_weights()
     total_images = get_image_count_per_seg_from_array(current_seg_id, seg_ids,
                                                       images_per_seg)
 
-    # Re-serialize the weights
-    # Update image count, check if we are done
     if agg_image_count == total_images:
-        # Once done with all images on a segment, we update weights
-        # with the total number of images here instead of the merge function.
-        # The merge function only deals with aggregating them.
-        updated_weights = [ total_images * w for w in updated_weights ]
-            # In GPDB, each segment would have a keras session, so clear
-            # them after the last buffer is processed.
-        clear_keras_session()
+        # For madlib_keras_fit_multiple_model(), we don't need to update weights
+        # with the total no of images as there is no merge function for it.
+        if not is_multiple_model:
+            updated_weights = [total_images * w for w in updated_weights]
+        if is_final_iteration or is_multiple_model:
+            SD_STORE.clear_SD(SD)
+            clear_keras_session(sess)
+            del segment_model
+            del sess
 
     new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
         agg_image_count, updated_weights)
@@ -451,8 +509,6 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     del x_train
     del y_train
 
-    end_transition = time.time()
-
     return new_state
 
 def fit_merge(state1, state2, **kwargs):
@@ -563,7 +619,8 @@ def validate_evaluate(module_name, model_table, model_summary_table, test_table,
 
 def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
                                     model_arch, serialized_weights, gpus_per_host,
-                                    segments_per_host, seg_ids, images_per_seg):
+                                    segments_per_host, seg_ids, images_per_seg,
+                                    is_final_iteration=True):
 
     gp_segment_id_col = '0' if is_platform_pg() else 'gp_segment_id'
 
@@ -591,7 +648,8 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
                                             ARRAY{seg_ids},
                                             ARRAY{images_per_seg},
                                             {gpus_per_host},
-                                            {segments_per_host}
+                                            {segments_per_host},
+                                            {is_final_iteration}
                                             )) as loss_metric
         from {table}
     """.format(**locals()), ["bytea"])
@@ -599,35 +657,42 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
     loss_metric = res[0]['loss_metric']
     return loss_metric
 
+
 def internal_keras_eval_transition(state, dependent_var, independent_var,
                                    dependent_var_shape, independent_var_shape,
                                    model_architecture, serialized_weights, compile_params,
                                    current_seg_id, seg_ids, images_per_seg,
-                                   gpus_per_host, segments_per_host, **kwargs):
+                                   gpus_per_host, segments_per_host,
+                                   is_final_iteration, **kwargs):
     SD = kwargs['SD']
     device_name = get_device_name_and_set_cuda_env(gpus_per_host, current_seg_id)
-
     agg_loss, agg_metric, agg_image_count = state
 
+    # This transition function is common to evaluate as well as the fit functions
+    # and is used to determine when to clear the session.
+    # For evaluate,
+    #   is_final_iteration is always set to true, so the session is cleared once
+    #   evaluated the last buffer on each segment.
+    # When called from fit functions,
+    #  if is_final_iteration is false, the fit function has already created a
+    #   session and a graph that can be used between iterations and cleared only
+    #   for the last buffer of last iteration
+    #  if is_final_iteration is true, we can clear the session after the last buffer
+
+    segment_model, sess = get_init_model_and_sess(SD, device_name, gpus_per_host,
+                                                  segments_per_host, model_architecture,
+                                                  compile_params)
     if not agg_image_count:
-        set_keras_session(device_name, gpus_per_host, segments_per_host)
-        model = model_from_json(model_architecture)
-        compile_and_set_weights(model, compile_params, device_name,
-                                serialized_weights)
-
-        SD['segment_model'] = model
         # These should already be 0, but just in case make sure
         agg_metric = 0
         agg_loss = 0
-    else:
-        # Same model every time, no need to re-compile or update weights
-        model = SD['segment_model']
+        set_model_weights(segment_model, serialized_weights)
 
     x_val = np_array_float32(independent_var, independent_var_shape)
     y_val = np_array_int16(dependent_var, dependent_var_shape)
 
     with K.tf.device(device_name):
-        res = model.evaluate(x_val, y_val)
+        res = segment_model.evaluate(x_val, y_val)
 
     # if metric is None, model.evaluate will only return loss as a scalar
     # Otherwise, it will return a list which has loss and metric
@@ -646,9 +711,12 @@ def internal_keras_eval_transition(state, dependent_var, independent_var,
     total_images = get_image_count_per_seg_from_array(current_seg_id, seg_ids,
                                                       images_per_seg)
 
-    if agg_image_count == total_images:
-        SD.pop('segment_model', None)
-        clear_keras_session()
+    if agg_image_count == total_images and is_final_iteration:
+        K.clear_session()
+        sess.close()
+        SD_STORE.clear_SD(SD)
+        del segment_model
+        del sess
 
     state[0] = agg_loss
     state[1] = agg_metric
@@ -684,7 +752,6 @@ def internal_keras_eval_final(state, **kwargs):
     return loss, metric
 
 
-
 def fit_help(schema_madlib, message, **kwargs):
     """
     Help function for keras fit
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index ccba02d..cf4f2d1 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -1718,7 +1718,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     images_per_seg             INTEGER[],
     gpus_per_host              INTEGER,
     segments_per_host          INTEGER,
-    prev_serialized_weights    BYTEA
+    prev_serialized_weights    BYTEA,
+    is_final_iteration         BOOLEAN
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.fit_transition(**globals())
@@ -1755,7 +1756,8 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step(
     INTEGER[],
     INTEGER,
     INTEGER,
-    BYTEA);
+    BYTEA,
+    BOOLEAN);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
     /* dep_var */                BYTEA,
     /* ind_var */                BYTEA,
@@ -1769,7 +1771,8 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
     /* images_per_seg*/          INTEGER[],
     /* gpus_per_host  */         INTEGER,
     /* segments_per_host  */     INTEGER,
-    /* serialized_weights */     BYTEA
+    /* serialized_weights */     BYTEA,
+    /* is_final_iteration */     BOOLEAN
 )(
     STYPE=BYTEA,
     SFUNC=MADLIB_SCHEMA.fit_transition,
@@ -1947,7 +1950,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_eval_transition(
     seg_ids                            INTEGER[],
     images_per_seg                     INTEGER[],
     gpus_per_host                      INTEGER,
-    segments_per_host                  INTEGER
+    segments_per_host                  INTEGER,
+    is_final_iteration                 BOOLEAN
 ) RETURNS REAL[3] AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.internal_keras_eval_transition(**globals())
@@ -1983,7 +1987,8 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.internal_keras_evaluate(
                                        INTEGER[],
                                        INTEGER[],
                                        INTEGER,
-                                       INTEGER);
+                                       INTEGER,
+                                       BOOLEAN);
 
 CREATE AGGREGATE MADLIB_SCHEMA.internal_keras_evaluate(
     /* dependent_var */                BYTEA,
@@ -1997,7 +2002,8 @@ CREATE AGGREGATE MADLIB_SCHEMA.internal_keras_evaluate(
     /* seg_ids */                      INTEGER[],
     /* images_per_seg*/                INTEGER[],
     /* gpus_per_host */                INTEGER,
-    /* segments_per_host */            INTEGER
+    /* segments_per_host */            INTEGER,
+    /* is_final_iteration */           BOOLEAN
 )(
     STYPE=REAL[3],
     INITCOND='{0,0,0}',
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
new file mode 100644
index 0000000..7569d39
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -0,0 +1,454 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+import time
+import sys
+
+from keras.models import *
+from madlib_keras import compute_loss_and_metrics
+from madlib_keras import get_initial_weights
+from madlib_keras import get_model_arch_weights
+from madlib_keras import get_segments_and_gpus
+from madlib_keras import get_source_summary_table_dict
+from madlib_keras_helper import *
+from madlib_keras_model_selection import ModelSelectionSchema
+from madlib_keras_validator import *
+from madlib_keras_wrapper import *
+
+from utilities.control import MinWarning
+from utilities.control import OptimizerControl
+from utilities.utilities import unique_string
+from utilities.utilities import add_postfix
+from utilities.utilities import rotate
+from utilities.utilities import madlib_version
+from utilities.utilities import is_platform_pg
+import json
+from collections import defaultdict
+import random
+import datetime
+mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
+mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+dist_key_col = DISTRIBUTION_KEY_COLNAME
+
+
+@MinWarning("warning")
+class FitMultipleModel():
+    def __init__(self, schema_madlib, source_table, model_output_table,
+                 model_selection_table, num_iterations,
+                 gpus_per_host=0, validation_table=None, **kwargs):
+        # set the random seed for visit order/scheduling
+        random.seed(1)
+        if is_platform_pg():
+            plpy.error(
+                "DL: Multiple model training is not supported on PostgreSQL.")
+        self.source_table = source_table
+        self.validation_table = validation_table
+        self.model_selection_table = model_selection_table
+        if self.model_selection_table:
+            self.model_selection_summary_table = add_postfix(self.model_selection_table, '_summary')
+        self.model_output_table = model_output_table
+        if self.model_output_table:
+            self.model_info_table = add_postfix(model_output_table, '_info')
+            self.model_summary_table = add_postfix(
+                model_output_table, '_summary')
+        self.num_iterations = num_iterations
+        self.module_name = 'madlib_keras_fit_multiple_model'
+        self.schema_madlib = schema_madlib
+        self.version = madlib_version(self.schema_madlib)
+        self.mst_key_col = ModelSelectionSchema.MST_KEY
+        self.model_id_col = ModelSelectionSchema.MODEL_ID
+        self.compile_params_col = ModelSelectionSchema.COMPILE_PARAMS
+        self.fit_params_col = ModelSelectionSchema.FIT_PARAMS
+        self.model_arch_table_col = ModelSelectionSchema.MODEL_ARCH_TABLE
+        self.model_weights_col=ModelArchSchema.MODEL_WEIGHTS
+        self.model_arch_col=ModelArchSchema.MODEL_ARCH
+        self.train_mst_metric_eval_time = defaultdict(list)
+        self.train_mst_loss = defaultdict(list)
+        self.train_mst_metric = defaultdict(list)
+        self.info_str = ""
+        self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
+        self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
+        self.fit_validator_train = FitMultipleInputValidator(
+            self.source_table, self.validation_table, self.model_output_table,
+            self.model_selection_table, self.model_selection_summary_table,
+            mb_dep_var_col, mb_indep_var_col, self.num_iterations,
+            self.model_info_table, self.mst_key_col, self.model_arch_table_col,
+            1, False)
+        self.msts = self.fit_validator_train.msts
+        self.model_arch_table = self.fit_validator_train.model_arch_table
+        self.seg_ids_train, self.images_per_seg_train = \
+            get_image_count_per_seg_for_minibatched_data_from_db(
+                self.source_table)
+
+        if self.validation_table:
+            self.valid_mst_metric_eval_time = defaultdict(list)
+            self.valid_mst_loss = defaultdict(list)
+            self.valid_mst_metric = defaultdict(list)
+            self.seg_ids_valid, self.images_per_seg_valid = \
+                get_image_count_per_seg_for_minibatched_data_from_db(
+                    self.validation_table)
+        self.mst_weights_tbl = unique_string(desp='mst_weights')
+        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
+
+        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
+        if len(self.msts) < len(self.dist_keys):
+            self.msts_for_schedule = self.msts + [None] * \
+                (len(self.dist_keys) - len(self.msts))
+        else:
+            self.msts_for_schedule = self.msts
+        random.shuffle(self.msts_for_schedule)
+        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
+        self.segments_per_host, self.gpus_per_host = get_segments_and_gpus(
+            gpus_per_host)
+        self.create_model_output_table()
+        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
+        self.fit_multiple_model()
+
+    def fit_multiple_model(self):
+        # WARNING: set orca off to prevent unwanted redistribution
+        with OptimizerControl(False):
+            original_cuda_env = None
+            if CUDA_VISIBLE_DEVICES_KEY in os.environ:
+                original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
+            self.start_training_time = datetime.datetime.now()
+            self.train_multiple_model()
+            self.end_training_time = datetime.datetime.now()
+            self.insert_info_table()
+            self.create_model_summary_table()
+        reset_cuda_env(original_cuda_env)
+
+    def train_multiple_model(self):
+        total_msts = len(self.msts_for_schedule)
+        for iter in range(self.num_iterations):
+            for mst_idx in range(total_msts):
+                mst_row = [self.grand_schedule[dist_key][mst_idx]
+                           for dist_key in self.dist_keys]
+                self.create_mst_schedule_table(mst_row)
+                if mst_idx == 0:
+                    start_iteration = time.time()
+                self.run_training()
+                if mst_idx == (total_msts - 1):
+                    end_iteration = time.time()
+                    self.info_str = "\tTime for training in iteration {0}: {1} sec\n".format(iter,
+                                                                                      end_iteration - start_iteration)
+            self.info_str += "\tTraining set after iteration {0}:".format(iter)
+            self.evaluate_model(iter, self.source_table, True)
+            if self.validation_table:
+                self.evaluate_model(iter, self.validation_table, False)
+            plpy.info("\n"+self.info_str)
+
+    def evaluate_model(self, epoch, table, is_train):
+        if is_train:
+            mst_metric_eval_time = self.train_mst_metric_eval_time
+            mst_loss = self.train_mst_loss
+            mst_metric = self.train_mst_metric
+            seg_ids = self.seg_ids_train
+            images_per_seg = self.images_per_seg_train
+        else:
+            mst_metric_eval_time = self.valid_mst_metric_eval_time
+            mst_loss = self.valid_mst_loss
+            mst_metric = self.valid_mst_metric
+            seg_ids = self.seg_ids_valid
+            images_per_seg = self.images_per_seg_valid
+        for mst in self.msts:
+            state = query_weights(self.model_output_table, self.model_weights_col,
+                self.mst_key_col, mst[self.mst_key_col])
+            model_arch, _ = get_model_arch_weights(self.model_arch_table, mst[self.model_id_col])
+            serialized_weights = \
+                madlib_keras_serializer.get_serialized_1d_weights_from_state(
+                    state)
+            metric_eval_time, metric, loss = compute_loss_and_metrics(
+                self.schema_madlib, table, "$madlib${0}$madlib$".format(
+                    mst[self.compile_params_col]),
+                model_arch,
+                serialized_weights,
+                self.gpus_per_host,
+                self.segments_per_host,
+                seg_ids,
+                images_per_seg, [], [], epoch, True)
+            mst_metric_eval_time[mst[self.mst_key_col]] \
+                .append(metric_eval_time)
+            mst_loss[mst[self.mst_key_col]].append(loss)
+            mst_metric[mst[self.mst_key_col]].append(metric)
+            self.info_str += "\n\tmst_key={0}: metric={1}, loss={2}".format(mst[self.mst_key_col], metric, loss)
+
+    def generate_schedule(self, msts):
+        """ Generate the schedule for models hopping to segments """
+        grand_schedule = {}
+        for index, dist_key in enumerate(self.dist_keys):
+            grand_schedule[dist_key] = rotate(msts, index)
+        return grand_schedule
+
+    def create_mst_schedule_table(self, mst_row):
+        mst_temp_query = """
+                         CREATE TEMP TABLE {self.mst_current_schedule_tbl}
+                                ({self.model_id_col} INTEGER,
+                                 {self.compile_params_col} VARCHAR,
+                                 {self.fit_params_col} VARCHAR,
+                                 {dist_key_col} INTEGER,
+                                 {self.mst_key_col} INTEGER)
+                         """.format(dist_key_col=dist_key_col, **locals())
+        plpy.execute(mst_temp_query)
+        for mst, dist_key in zip(mst_row, self.dist_keys):
+            if mst:
+                model_id = mst[self.model_id_col]
+                compile_params = mst[self.compile_params_col]
+                fit_params = mst[self.fit_params_col]
+                mst_key = mst[self.mst_key_col]
+            else:
+                model_id = "NULL"
+                compile_params = "NULL"
+                fit_params = "NULL"
+                mst_key = "NULL"
+            mst_insert_query = """
+                               INSERT INTO {self.mst_current_schedule_tbl}
+                                   VALUES ({model_id},
+                                           $madlib${compile_params}$madlib$,
+                                           $madlib${fit_params}$madlib$,
+                                           {dist_key},
+                                           {mst_key})
+                                """.format(**locals())
+            plpy.execute(mst_insert_query)
+
+    def create_model_output_table(self):
+        """ Create the model output and info tables and insert one row per mst into each """
+        output_table_create_query = """
+                                    CREATE TABLE {self.model_output_table}
+                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
+                                     {self.model_weights_col} BYTEA,
+                                     {self.model_arch_col} JSON)
+                                    """.format(self=self)
+        info_table_create_query = """
+                                  CREATE TABLE {self.model_info_table}
+                                  ({self.mst_key_col} INTEGER PRIMARY KEY,
+                                   {self.model_id_col} INTEGER,
+                                   {self.compile_params_col} TEXT,
+                                   {self.fit_params_col} TEXT,
+                                   model_type TEXT,
+                                   model_size DOUBLE PRECISION,
+                                   metrics_elapsed_time DOUBLE PRECISION[],
+                                   metrics_type TEXT[],
+                                   training_metrics_final DOUBLE PRECISION,
+                                   training_loss_final DOUBLE PRECISION,
+                                   training_metrics DOUBLE PRECISION[],
+                                   training_loss DOUBLE PRECISION[],
+                                   validation_metrics_final DOUBLE PRECISION,
+                                   validation_loss_final DOUBLE PRECISION,
+                                   validation_metrics DOUBLE PRECISION[],
+                                   validation_loss DOUBLE PRECISION[])
+                               """.format(self=self)
+
+        plpy.execute(output_table_create_query)
+        plpy.execute(info_table_create_query)
+        for mst in self.msts:
+            model_arch, model_weights = get_model_arch_weights(self.model_arch_table,
+                                                               mst[self.model_id_col])
+            serialized_weights = get_initial_weights(self.model_output_table,
+                                                     model_arch,
+                                                     model_weights,
+                                                     False,
+                                                     self.gpus_per_host
+                                                     )
+            model = model_from_json(model_arch)
+
+            serialized_state = \
+                madlib_keras_serializer.serialize_state_with_nd_weights(
+                    0, model.get_weights())
+            # NOTE(review): serialized_weights is used only for model_size below.
+            model_size = sys.getsizeof(serialized_weights) / 1024.0
+            metrics_list = get_metrics_from_compile_param(
+                mst[self.compile_params_col])
+            is_metrics_specified = True if metrics_list else False
+            metrics_type = 'ARRAY{0}'.format(
+                metrics_list) if is_metrics_specified else 'NULL'
+
+            output_table_insert_query = """
+                                INSERT INTO {self.model_output_table}(
+                                    {self.mst_key_col}, {self.model_weights_col},
+                                    {self.model_arch_col})
+                                VALUES ({mst_key}, $1, $2)
+                                   """.format(self=self,
+                                              mst_key=mst[self.mst_key_col])
+            output_table_insert_query_prepared = plpy.prepare(
+                output_table_insert_query, ["bytea", "json"])
+            plpy.execute(output_table_insert_query_prepared, [
+                         serialized_state, json.dumps(model_arch)])
+            info_table_insert_query = """
+                    INSERT INTO {self.model_info_table}({self.mst_key_col},
+                                {self.model_id_col}, {self.compile_params_col},
+                                {self.fit_params_col}, model_type, model_size,
+                                metrics_type)
+                        VALUES ({mst_key_val}, {model_id},
+                                $madlib${compile_params}$madlib$,
+                                $madlib${fit_params}$madlib$, '{model_type}',
+                                {model_size}, {metrics_type})
+                """.format(self=self,
+                           mst_key_val=mst[self.mst_key_col],
+                           model_id=mst[self.model_id_col],
+                           compile_params=mst[self.compile_params_col],
+                           fit_params=mst[self.fit_params_col],
+                           model_type='madlib_keras',
+                           model_size=model_size,
+                           metrics_type=metrics_type)
+            plpy.execute(info_table_insert_query)
+
+    def create_model_summary_table(self):
+        src_summary_dict = get_source_summary_table_dict(self.fit_validator_train)
+        class_values = src_summary_dict['class_values']
+        dep_vartype = src_summary_dict['dep_vartype']
+        dependent_varname = \
+            src_summary_dict['dependent_varname_in_source_table']
+        independent_varname = \
+            src_summary_dict['independent_varname_in_source_table']
+        norm_const = src_summary_dict['norm_const']
+        num_classes = len(class_values)
+        class_values_colname = CLASS_VALUES_COLNAME
+        dependent_vartype_colname = DEPENDENT_VARTYPE_COLNAME
+        normalizing_const_colname = NORMALIZING_CONST_COLNAME
+        float32_sql_type = FLOAT32_SQL_TYPE
+        update_query = """
+                CREATE TABLE {self.model_summary_table} AS
+                SELECT
+                    $MAD${self.source_table}$MAD$::TEXT AS source_table,
+                    $MAD${self.validation_table}$MAD$::TEXT AS validation_table,
+                    $MAD${self.model_output_table}$MAD$::TEXT AS model,
+                    $MAD${self.model_info_table}$MAD$::TEXT AS model_info,
+                    $MAD${dependent_varname}$MAD$::TEXT AS dependent_varname,
+                    $MAD${independent_varname}$MAD$::TEXT AS independent_varname,
+                    $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
+                    {self.num_iterations}::INTEGER AS num_iterations,
+                    '{self.start_training_time}'::TIMESTAMP AS start_training_time,
+                    '{self.end_training_time}'::TIMESTAMP AS end_training_time,
+                    '{self.version}'::TEXT AS madlib_version,
+                    {num_classes}::INTEGER AS num_classes,
+                    ARRAY{class_values}::TEXT[] AS {class_values_colname},
+                    $MAD${dep_vartype}$MAD$::TEXT AS {dependent_vartype_colname},
+                    {norm_const}::{float32_sql_type} AS {normalizing_const_colname}
+            """.format(**locals())
+        plpy.execute(update_query)
+
+    def update_info_table(self, mst, is_train):
+        mst_key = mst[self.mst_key_col]
+        metrics, metrics_final, metrics_elapsed_time = \
+            "NULL", "NULL", "NULL"
+        if is_train:
+            mst_metric = self.train_mst_metric
+            mst_metric_eval_time = self.train_mst_metric_eval_time
+            mst_loss = self.train_mst_loss
+        else:
+            mst_metric = self.valid_mst_metric
+            mst_metric_eval_time = self.valid_mst_metric_eval_time
+            mst_loss = self.valid_mst_loss
+
+        if mst_key in mst_metric:
+            metrics = mst_metric[mst_key]
+            metrics_final = metrics[-1]
+            metrics_elapsed_time = mst_metric_eval_time[mst_key]
+            metrics = "ARRAY{}".format(metrics)
+            metrics_elapsed_time = "ARRAY{}".format(metrics_elapsed_time)
+        loss = mst_loss[mst_key]
+        loss_final = loss[-1]
+        loss = "ARRAY{}".format(loss)
+        if is_train:
+            update_query = """
+                           UPDATE {self.model_info_table} SET
+                           training_metrics_final = {metrics_final},
+                           training_loss_final = {loss_final},
+                           metrics_elapsed_time = {metrics_elapsed_time},
+                           training_metrics = {metrics},
+                           training_loss = {loss}
+                           WHERE {self.mst_key_col} = {mst_key}
+                           """.format(**locals())
+        else:
+            update_query = """
+                           UPDATE {self.model_info_table} SET
+                           validation_metrics_final = {metrics_final},
+                           validation_loss_final = {loss_final},
+                           metrics_elapsed_time = {metrics_elapsed_time},
+                           validation_metrics = {metrics},
+                           validation_loss = {loss}
+                           WHERE {self.mst_key_col} = {mst_key}
+                           """.format(**locals())
+        plpy.execute(update_query)
+
+    def insert_info_table(self):
+        for mst in self.msts:
+            self.update_info_table(mst, True)
+            if self.validation_table:
+                self.update_info_table(mst, False)
+
+    def run_training(self):
+        mst_weights_query = """
+            CREATE TEMP TABLE {self.mst_weights_tbl} AS
+                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
+                       model_arch_tbl.{self.model_arch_col}
+                FROM
+                    {self.mst_current_schedule_tbl} mst_tbl
+                    LEFT JOIN {self.model_output_table} wgh_tbl
+                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
+                        LEFT JOIN {self.model_arch_table} model_arch_tbl
+                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
+                DISTRIBUTED BY ({dist_key_col})
+        """.format(dist_key_col=dist_key_col,
+                   **locals())
+        plpy.execute(mst_weights_query)
+        uda_query = """
+            CREATE TABLE {self.weights_to_update_tbl} AS
+            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
+                                                      {mb_indep_var_col},
+                                                      {self.dep_shape_col},
+                                                      {self.ind_shape_col},
+                                                      {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
+                                                      {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
+                                                      {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
+                                                      src.gp_segment_id,
+                                                      ARRAY{self.seg_ids_train},
+                                                      ARRAY{self.images_per_seg_train},
+                                                      {self.gpus_per_host},
+                                                      {self.segments_per_host},
+                                                      {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
+                                                      {is_final_iteration}::BOOLEAN
+                                                      )::BYTEA AS {self.model_weights_col},
+                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
+                ,src.{dist_key_col} AS {dist_key_col}
+                FROM {self.source_table} src JOIN {self.mst_weights_tbl}
+                    USING ({dist_key_col})
+                WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL
+                GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
+            DISTRIBUTED BY({dist_key_col})
+            """.format(mb_dep_var_col=mb_dep_var_col,
+                       mb_indep_var_col=mb_indep_var_col,
+                       is_final_iteration=True,
+                       dist_key_col=dist_key_col,
+                       self=self
+                       )
+        plpy.execute(uda_query)
+
+        update_query = """
+            UPDATE {self.model_output_table}
+            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
+            FROM {self.weights_to_update_tbl}
+            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
+        """.format(self=self)
+        plpy.execute(update_query)
+        plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}".format(
+                                                        self.mst_weights_tbl,
+                                                        self.mst_current_schedule_tbl,
+                                                        self.weights_to_update_tbl))
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
new file mode 100644
index 0000000..090d858
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -0,0 +1,112 @@
+/* ----------------------------------------------------------------------- *//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *
+ * @file madlib_keras_fit_multiple_model.sql_in
+ *
+ * @brief SQL functions for model hopper distributed training
+ * @date August 2019
+ *
+ *
+ *//* ----------------------------------------------------------------------- */
+
+m4_include(`SQLCommon.m4')
+
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR  -- NULL when called through the 5-argument overload
+) RETURNS VOID AS $$
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
+    with AOControl(False):
+        fit_obj = madlib_keras_fit_multiple_model.FitMultipleModel(**globals())  # only the constructor is invoked; fit_obj is unused afterwards
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, NULL);  -- forwards to the 6-argument version with validation_table = NULL
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
+    state                      BYTEA,  -- aggregate state; this function is the SFUNC of fit_step_multiple_model
+    dependent_var              BYTEA,
+    independent_var            BYTEA,
+    dependent_var_shape        INTEGER[],
+    independent_var_shape      INTEGER[],
+    model_architecture         TEXT,
+    compile_params             TEXT,
+    fit_params                 TEXT,
+    current_seg_id             INTEGER,
+    seg_ids                    INTEGER[],
+    images_per_seg             INTEGER[],
+    gpus_per_host              INTEGER,
+    segments_per_host          INTEGER,
+    prev_serialized_weights    BYTEA,
+    is_final_iteration         BOOLEAN
+) RETURNS BYTEA AS $$
+PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
+    return madlib_keras.fit_transition(is_multiple_model = True, **globals())  # SQL arguments reach fit_transition through globals()
+$$ LANGUAGE plpythonu
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
+
+DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step_multiple_model(  -- argument types must match CREATE AGGREGATE below
+    /* dep_var */                BYTEA,
+    /* ind_var */                BYTEA,
+    /* dep_var_shape */          INTEGER[],
+    /* ind_var_shape */          INTEGER[],
+    /* model_architecture */     TEXT,
+    /* compile_params */         TEXT,
+    /* fit_params */             TEXT,
+    /* current_seg_id */         INTEGER,
+    /* seg_ids */                INTEGER[],
+    /* images_per_seg */         INTEGER[],
+    /* gpus_per_host */          INTEGER,
+    /* segments_per_host */      INTEGER,
+    /* serialized_weights */     BYTEA,
+    /* is_final_iteration */     BOOLEAN);
+CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
+    /* dep_var */                BYTEA,
+    /* ind_var */                BYTEA,
+    /* dep_var_shape */          INTEGER[],
+    /* ind_var_shape */          INTEGER[],
+    /* model_architecture */     TEXT,
+    /* compile_params */         TEXT,
+    /* fit_params */             TEXT,
+    /* current_seg_id */         INTEGER,
+    /* seg_ids*/                 INTEGER[],
+    /* images_per_seg*/          INTEGER[],
+    /* gpus_per_host  */         INTEGER,
+    /* segments_per_host  */     INTEGER,
+    /* serialized_weights */     BYTEA,
+    /* is_final_iteration */     BOOLEAN
+)(
+    STYPE=BYTEA,
+    SFUNC=MADLIB_SCHEMA.fit_transition_multiple_model
+);
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index b198f02..e91435a 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -129,15 +129,12 @@ def get_image_count_per_seg_for_minibatched_data_from_db(table_name):
     mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
 
     shape_col = add_postfix(mb_dep_var_col, "_shape")
-    plpy.info(table_name)
-    plpy.info(shape_col)
 
     if is_platform_pg():
         res = plpy.execute(
             """ SELECT {0}::SMALLINT[] AS shape
                 FROM {1}
             """.format(shape_col, table_name))
-        plpy.info(res)
 
         images_per_seg = [sum(r['shape'][0] for r in res)]
         seg_ids = [0]
@@ -193,3 +190,39 @@ def parse_shape(shape):
     # Split on :, discard the first one [1:],
     # split each piece on ], take the first piece [0], convert to int
     return [int(a.split(']')[0]) for a in shape.split(':')[1:]]
+
+
+def query_model_configs(model_selection_table, model_selection_summary_table,
+    mst_key_col, model_arch_table_col):
+    msts_query = """
+                 SELECT * FROM {model_selection_table}
+                 ORDER BY {mst_key_col}
+                 """.format(**locals())
+    model_arch_table_query = """
+                             SELECT {model_arch_table_col}
+                             FROM {model_selection_summary_table}
+                             """.format(**locals())
+    msts = list(plpy.execute(msts_query))
+    model_arch_table = plpy.execute(model_arch_table_query)[0][model_arch_table_col]
+    return msts, model_arch_table
+
+def query_dist_keys(source_table, dist_key_col):
+    """ Read distinct keys from the source table """
+    dist_key_query = """
+                     SELECT DISTINCT({dist_key_col}) FROM {source_table}
+                     ORDER BY {dist_key_col}
+                     """.format(dist_key_col=dist_key_col,
+                                source_table=source_table)
+    res = list(plpy.execute(dist_key_query))
+    res = [x[dist_key_col] for x in res]
+    return res
+
+def query_weights(model_output_table, model_weights_col, mst_key_col, mst_key):
+    mlp_weights_query = """
+                        SELECT {model_weights_col}, {mst_key_col}
+                        FROM {model_output_table}
+                        WHERE {mst_key_col} = {mst_key}
+                        """.format(**locals())
+
+    res = plpy.execute(mlp_weights_query)
+    return res[0][model_weights_col]
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
index e642034..5d2cbbf 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
@@ -22,7 +22,15 @@ from madlib_keras_validator import MstLoaderInputValidator
 from utilities.control import MinWarning
 from utilities.utilities import add_postfix
 from madlib_keras_wrapper import convert_string_of_args_to_dict
+from keras_model_arch_table import ModelArchSchema
 
+class ModelSelectionSchema:
+    MST_KEY = 'mst_key'  # column name; created as SERIAL in the model selection table
+    MODEL_ID = ModelArchSchema.MODEL_ID  # reuses the name defined by ModelArchSchema
+    MODEL_ARCH_TABLE = 'model_arch_table'  # column name in the model selection summary table
+    COMPILE_PARAMS = 'compile_params'  # column name in the model selection table
+    FIT_PARAMS = 'fit_params'  # column name in the model selection table
+    col_types = ('SERIAL', 'INTEGER', 'VARCHAR', 'VARCHAR')  # presumably the types of MST_KEY, MODEL_ID, COMPILE_PARAMS, FIT_PARAMS -- confirm
 
 @MinWarning("warning")
 class MstLoader():
@@ -66,7 +74,7 @@ class MstLoader():
             compile_params_list)
         self.fit_params_list = self.params_preprocessed(fit_params_list)
 
-	self.msts = []
+        self.msts = []
 
         self.find_combinations()
 
@@ -88,22 +96,22 @@ class MstLoader():
             list: The preprocessed list of strings.
         """
 
-	dict_dedup = {}
-	for string in list_strs:
-	    d = convert_string_of_args_to_dict(string)
-	    hash_tuple = tuple( '{0} = {1}'\
-		.format(x, d[x]) for x in sorted(d.keys()))
-	    dict_dedup[hash_tuple] = string
+        dict_dedup = {}
+        for string in list_strs:
+            d = convert_string_of_args_to_dict(string)
+            hash_tuple = tuple( '{0} = {1}'\
+            .format(x, d[x]) for x in sorted(d.keys()))
+            dict_dedup[hash_tuple] = string
 
-	return dict_dedup.values()
+        return dict_dedup.values()
 
     def find_combinations(self):
         """Backtracking helper for generating the combinations.
         """
         param_grid = OrderedDict([
-            ('model_arch_id', self.model_arch_id_list),
-            ('compile_params', self.compile_params_list),
-            ('fit_params', self.fit_params_list)
+            (ModelSelectionSchema.MODEL_ID, self.model_arch_id_list),
+            (ModelSelectionSchema.COMPILE_PARAMS, self.compile_params_list),
+            (ModelSelectionSchema.FIT_PARAMS, self.fit_params_list)
         ])
 
         def find_combinations_helper(msts, p, i):
@@ -121,13 +129,17 @@ class MstLoader():
         """
         create_query = """
                         CREATE TABLE {self.model_selection_table} (
-                            mst_key SERIAL,
-                            model_arch_id INTEGER,
-                            compile_params VARCHAR,
-                            fit_params VARCHAR,
-                            unique (model_arch_id, compile_params, fit_params)
+                            {mst_key} SERIAL,
+                            {model_arch_id} INTEGER,
+                            {compile_params} VARCHAR,
+                            {fit_params} VARCHAR,
+                            unique ({model_arch_id}, {compile_params}, {fit_params})
                         );
-                       """.format(self=self)
+                       """.format(self=self,
+                                  mst_key=ModelSelectionSchema.MST_KEY,
+                                  model_arch_id=ModelSelectionSchema.MODEL_ID,
+                                  compile_params=ModelSelectionSchema.COMPILE_PARAMS,
+                                  fit_params=ModelSelectionSchema.FIT_PARAMS)
         with MinWarning('warning'):
             plpy.execute(create_query)
 
@@ -136,9 +148,10 @@ class MstLoader():
         """
         create_query = """
                         CREATE TABLE {self.model_selection_summary_table} (
-                            model_arch_table VARCHAR
+                            {model_arch_table} VARCHAR
                         );
-                       """.format(self=self)
+                       """.format(self=self,
+                                  model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE)
         with MinWarning('warning'):
             plpy.execute(create_query)
 
@@ -146,30 +159,34 @@ class MstLoader():
         """Insert every thing in self.msts into the mst table.
         """
         for mst in self.msts:
-            model_arch_id = mst['model_arch_id']
-            compile_params = mst['compile_params']
-            fit_params = mst['fit_params']
+            model_arch_id = mst[ModelSelectionSchema.MODEL_ID]
+            compile_params = mst[ModelSelectionSchema.COMPILE_PARAMS]
+            fit_params = mst[ModelSelectionSchema.FIT_PARAMS]
             insert_query = """
                             INSERT INTO
                                 {self.model_selection_table}(
-                                    model_arch_id,
-                                    compile_params,
-                                    fit_params
+                                    {model_arch_id_col},
+                                    {compile_params_col},
+                                    {fit_params_col}
                                 )
                             VALUES (
                                 {model_arch_id},
                                 $${compile_params}$$,
                                 $${fit_params}$$
                             )
-                           """.format(**locals())
+                           """.format(model_arch_id_col=ModelSelectionSchema.MODEL_ID,
+                                      compile_params_col=ModelSelectionSchema.COMPILE_PARAMS,
+                                      fit_params_col=ModelSelectionSchema.FIT_PARAMS,
+                                      **locals())
             plpy.execute(insert_query)
         insert_summary_query = """
                         INSERT INTO
                             {self.model_selection_summary_table}(
-                                model_arch_table
+                                {model_arch_table_name}
                         )
                         VALUES (
                             $${self.model_arch_table}$$
                         )
-                       """.format(**locals())
+                       """.format(model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
+                                  **locals())
         plpy.execute(insert_summary_query)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index c73f919..4d94936 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -244,8 +244,7 @@ def internal_keras_predict(independent_var, model_architecture, model_data,
         if model_key not in SD:
             set_keras_session(device_name, gpus_per_host, segments_per_host)
             model = model_from_json(model_architecture)
-            model_shapes = get_model_shapes(model)
-            set_model_weights(model, device_name, model_data, model_shapes)
+            set_model_weights(model, model_data)
 
             SD[model_key] = model
             SD[row_count_key] = 0
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
index b2ccce4..d70a2f8 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
@@ -17,6 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+from utilities.utilities import _assert
 
 # TODO
 # 1. Current serializing logic
@@ -53,7 +54,10 @@ def get_image_count_from_state(state):
     and weights
     :return: image count as float
     """
-    image_count , _  = deserialize_as_image_1d_weights(state)
+    if not state:
+        image_count = 0
+    else:
+        image_count , _  = deserialize_as_image_1d_weights(state)
     return image_count
 
 def get_serialized_1d_weights_from_state(state):
@@ -155,6 +159,13 @@ def deserialize_as_nd_weights(model_weights_serialized, model_shapes):
 
     i, j, model_weights = 0, 0, []
     model_weights_serialized = np.fromstring(model_weights_serialized, dtype=np.float32)
+
+    total_model_shape = \
+        sum([reduce(lambda x, y: x * y, ls) for ls in model_shapes])
+    total_weights_shape = model_weights_serialized.size
+    _assert(total_model_shape == total_weights_shape,
+            "Number of elements in model weights({0}) doesn't match model({1})."\
+                .format(total_weights_shape, total_model_shape))
     while j < len(model_shapes):
         next_pointer = i + reduce(lambda x, y: x * y, model_shapes[j])
         weight_arr_portion = model_weights_serialized[i:next_pointer]
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 6536842..01dc490 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -33,6 +33,7 @@ from madlib_keras_helper import MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
 from madlib_keras_helper import MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
 from madlib_keras_helper import METRIC_TYPE_COLNAME
 from madlib_keras_helper import parse_shape
+from madlib_keras_helper import query_model_configs
 
 from utilities.minibatch_validation import validate_bytea_var_for_minibatch
 from utilities.utilities import _assert
@@ -231,13 +232,11 @@ class InputValidator:
                 module_name, model_summary_table, cols_to_check_for))
 
 
-
-
-class FitInputValidator:
+class FitCommonValidator(object):
     def __init__(self, source_table, validation_table, output_model_table,
                  model_arch_table, model_arch_id, dependent_varname,
                  independent_varname, num_iterations,
-                 metrics_compute_frequency, warm_start):
+                 metrics_compute_frequency, warm_start, module_name):
         self.source_table = source_table
         self.validation_table = validation_table
         self.output_model_table = output_model_table
@@ -257,10 +256,10 @@ class FitInputValidator:
         if self.output_model_table:
             self.output_summary_model_table = add_postfix(
                 self.output_model_table, "_summary")
-        self.module_name = 'madlib_keras_fit'
-        self._validate_input_args()
+        self.module_name = module_name
+        self._validate_common_args()
 
-    def _validate_input_args(self):
+    def _validate_common_args(self):
         _assert(self.num_iterations > 0,
             "{0}: Number of iterations cannot be < 1.".format(self.module_name))
         _assert(self._is_valid_metrics_compute_frequency(),
@@ -281,8 +280,6 @@ class FitInputValidator:
                                          self.dependent_varname)
 
         self._validate_validation_table()
-        InputValidator.validate_model_arch_table(self.module_name, self.model_arch_table,
-            self.model_arch_id)
         if self.warm_start:
             input_tbl_valid(self.output_model_table, self.module_name)
             input_tbl_valid(self.output_summary_model_table, self.module_name)
@@ -353,6 +350,60 @@ class FitInputValidator:
                 input_shape, 2)
 
 
+class FitInputValidator(FitCommonValidator):
+    def __init__(self, source_table, validation_table, output_model_table,
+                 model_arch_table, model_arch_id, dependent_varname,
+                 independent_varname, num_iterations,
+                 metrics_compute_frequency, warm_start):
+
+        self.module_name = 'madlib_keras_fit'
+        super(FitInputValidator, self).__init__(source_table,
+                                                validation_table,
+                                                output_model_table,
+                                                model_arch_table,
+                                                model_arch_id,
+                                                dependent_varname,
+                                                independent_varname,
+                                                num_iterations,
+                                                metrics_compute_frequency,
+                                                warm_start,
+                                                self.module_name)
+        InputValidator.validate_model_arch_table(self.module_name, self.model_arch_table,
+            self.model_arch_id)
+
+class FitMultipleInputValidator(FitCommonValidator):
+    def __init__(self, source_table, validation_table, output_model_table,
+                 model_selection_table, model_selection_summary_table, dependent_varname,
+                 independent_varname, num_iterations, model_info_table, mst_key_col,
+                 model_arch_table_col, metrics_compute_frequency, warm_start):
+
+        self.module_name = 'madlib_keras_fit_multiple'
+
+        input_tbl_valid(model_selection_table, self.module_name)
+        input_tbl_valid(model_selection_summary_table, self.module_name,
+                        error_suffix_str="Please ensure that the model selection table ({0}) "
+                                         "has been created by "
+                                         "load_model_selection_table().".format(
+                                            model_selection_table))
+        self.msts, self.model_arch_table = query_model_configs(
+            model_selection_table, model_selection_summary_table,
+            mst_key_col, model_arch_table_col)
+        output_tbl_valid(model_info_table, self.module_name)
+        super(FitMultipleInputValidator, self).__init__(source_table,
+                                                        validation_table,
+                                                        output_model_table,
+                                                        self.model_arch_table,
+                                                        None,
+                                                        dependent_varname,
+                                                        independent_varname,
+                                                        num_iterations,
+                                                        metrics_compute_frequency,
+                                                        warm_start,
+                                                        self.module_name)
+
+
+
+
 
 class MstLoaderInputValidator():
     def __init__(self,
@@ -421,3 +472,4 @@ class MstLoaderInputValidator():
         output_tbl_valid(self.model_selection_table, self.module_name)
         output_tbl_valid(self.model_selection_summary_table, self.module_name)
 
+
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 1961a00..73df519 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -71,14 +71,25 @@ def get_device_name_and_set_cuda_env(gpus_per_host, seg):
 
 def set_keras_session(device_name, gpus_per_host, segments_per_host):
     with K.tf.device(device_name):
-        config = K.tf.ConfigProto()
-        if gpus_per_host > 0:
-            memory_fraction = get_gpu_memory_fraction(gpus_per_host, segments_per_host)
-            config.gpu_options.allow_growth = False
-            config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
-        session = K.tf.Session(config=config)
+        session = get_keras_session(device_name, gpus_per_host, segments_per_host)
         K.set_session(session)
 
+def get_keras_session(device_name, gpus_per_host, segments_per_host):
+    config = K.tf.ConfigProto()
+    if gpus_per_host > 0:
+        memory_fraction = get_gpu_memory_fraction(gpus_per_host, segments_per_host)
+        config.gpu_options.allow_growth = False
+        config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
+    session = tf.Session(config=config)
+    return session
+
+def clear_keras_session(sess = None):
+    if sess is None:
+        sess = K.get_session()
+    K.clear_session()
+    sess.close()
+
+
 def get_gpu_memory_fraction(gpus_per_host, segments_per_host):
     """
     We cap the gpu memory usage to 90% of the total available gpu memory.
@@ -89,11 +100,6 @@ def get_gpu_memory_fraction(gpus_per_host, segments_per_host):
     """
     return 0.9 / ceil(1.0 * segments_per_host / gpus_per_host)
 
-def clear_keras_session():
-    sess = K.get_session()
-    K.clear_session()
-    sess.close()
-
 def get_model_shapes(model):
     model_shapes = []
     for a in model.get_weights():
@@ -103,20 +109,19 @@ def get_model_shapes(model):
 def compile_and_set_weights(segment_model, compile_params, device_name,
                             serialized_weights):
     model_shapes = get_model_shapes(segment_model)
-    with K.tf.device(device_name):
-        compile_model(segment_model, compile_params)
-        model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
-            serialized_weights, model_shapes)
-        segment_model.set_weights(model_weights)
+    compile_model(segment_model, compile_params)
+    model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
+        serialized_weights, model_shapes)
+    segment_model.set_weights(model_weights)
 
 # TODO: This can be refactored to be part of compile_and_set_weights(),
 # by making compile_params an optional param in that function. Doing that
 # now might create more merge conflicts with other JIRAs, so get to this later.
-def set_model_weights(segment_model, device_name, serialized_weights, model_shapes):
-    with K.tf.device(device_name):
-        model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
-            serialized_weights, model_shapes)
-        segment_model.set_weights(model_weights)
+def set_model_weights(segment_model, serialized_weights):
+    model_shapes = get_model_shapes(segment_model)
+    model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
+        serialized_weights, model_shapes)
+    segment_model.set_weights(model_weights)
 
 """
 Used to convert compile_params and fit_params to actual argument dictionaries
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 41fe257..c4f5b7a 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -19,39 +19,14 @@
  *
  *//* ---------------------------------------------------------------------*/
 
--- MST table generation tests
+m4_include(`SQLCommon.m4')
 
--- First set up model arch table, to use as input
-DROP TABLE IF EXISTS iris_model_arch;
--- NOTE: The seed is set to 0 for every layer.
-SELECT load_keras_model('iris_model_arch',  -- Output table,
-$$
-{
-"class_name": "Sequential",
-"keras_version": "2.1.6",
-"config":
-    [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling",
-    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": "fan_avg"}},
-    "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null,
-    "bias_constraint": null, "dtype": "float32", "activation": "relu", "trainable": true,
-    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
-    "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": true,
-    "activity_regularizer": null}}, {"class_name": "Dense",
-    "config": {"kernel_initializer": {"class_name": "VarianceScaling",
-    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": "fan_avg"}},
-    "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null,
-    "bias_constraint": null, "activation": "relu", "trainable": true, "kernel_regularizer": null,
-    "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "use_bias": true,
-    "activity_regularizer": null}}, {"class_name": "Dense", "config": {"kernel_initializer":
-    {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0,
-    "seed": 0, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": null,
-    "bias_regularizer": null, "bias_constraint": null, "activation": "softmax",
-    "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
-    "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": null}}],
-    "backend": "tensorflow"}
-$$
-);
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
 
+-- MST table generation tests
 -- Valid inputs should pass and yield 6 msts in the table
 DROP TABLE IF EXISTS mst_table, mst_table_summary;
 SELECT load_model_selection_table(
@@ -112,7 +87,7 @@ SELECT assert(trap_error($TRAP$
     );
 $TRAP$)=1, 'Should error out if the provided parameters are not valid');
 
--- Must deduplicate, options with extrac white spaces should not be considered
+-- Must deduplicate, options with extra white spaces should not be considered
 -- as distinct params.
 
 DROP TABLE IF EXISTS mst_table, mst_table_summary;
@@ -158,3 +133,180 @@ SELECT assert(
 )
 FROM mst_table;
 
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+-- Multiple models test
+-- Prepare model selection table with three rows
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=50, epochs=1$$
+    ]
+);
+-- Prepare model selection table with only one row
+DROP TABLE IF EXISTS mst_table_1row, mst_table_1row_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_table_1row',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=16, epochs=1$$
+    ]
+);
+-- Prepare model selection table with four rows
+DROP TABLE IF EXISTS mst_table_4row, mst_table_4row_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_table_4row',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=16, epochs=1$$,
+        $$batch_size=32, epochs=1$$
+    ]
+);
+
+-- Test when number of configs(3) equals number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table',
+	6,
+	0
+);
+
+SELECT assert(
+        model_arch_table = 'iris_model_arch' AND
+        model_info = 'iris_multiple_model_info' AND
+        source_table = 'iris_data_packed' AND
+        model = 'iris_multiple_model' AND
+        dependent_varname = 'class_text' AND
+        independent_varname = 'attributes' AND
+        madlib_version is NOT NULL AND
+        num_iterations = 6 AND
+        num_classes = 3 AND
+        class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+        normalizing_const = 1,
+        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
+SELECT assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        fit_params = $MAD$batch_size=50, epochs=1$MAD$::text AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 6 AND
+        array_upper(training_loss, 1) = 6 AND
+        array_upper(metrics_elapsed_time, 1) = 6,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
+
+SELECT assert(
+  training_loss[6]-training_loss[1] < 0 AND
+  training_metrics[6]-training_metrics[1] > 0,
+    'The loss and accuracy should have improved with more iterations.'
+)
+FROM iris_multiple_model_info
+WHERE compile_params like '%lr=0.01%';
+
+-- Test when number of configs(1) is less than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table_1row',
+	3,
+	0
+);
+
+SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 3 AND
+        array_upper(training_loss, 1) = 3 AND
+        array_upper(metrics_elapsed_time, 1) = 3,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+
+-- Test when number of configs(4) is larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table_4row',
+	3,
+	0
+);
+
+SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 3 AND
+        array_upper(training_loss, 1) = 3 AND
+        array_upper(metrics_elapsed_time, 1) = 3,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
+AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+!>)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
index 7a27c7b..b8b3780 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
@@ -274,12 +274,15 @@ FROM cifar10_predict;
 -- Predict with correctly shaped data, must go thru.
 -- Update output_summary table to reflect
 -- class_values, num_classes and model_arch_id for shaped data
-UPDATE keras_saved_out
-SET model_arch = (SELECT model_arch from model_arch where model_id = 3);
-UPDATE keras_saved_out_summary
-SET model_arch_id = 3,
-    num_classes = 3,
-    class_values = ARRAY[0,NULL,NULL]::INTEGER[];
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_test_shape_batched',
+    'keras_saved_out',
+    'model_arch',
+    3,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3);
 
 DROP TABLE IF EXISTS cifar10_predict;
 SELECT madlib_keras_predict(
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index af48618..7161dc8 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -86,12 +86,39 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
     def _test_fit_transition_first_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `clear_keras_session`
+        # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
+        self.subject.K.clear_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
         starting_image_count = 0
         ending_image_count = len(self.dependent_var_int)
+
+        # last iteration Call
+        previous_state = np.array(self.model_weights, dtype=np.float32)
+
+        k = {'SD': {}}
+
+        new_state = self.subject.fit_transition(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.all_seg_ids, self.total_images_per_seg, 0, 4,
+            previous_state.tostring(), True, **k)
+        state = np.fromstring(new_state, dtype=np.float32)
+        image_count = state[0]
+        weights = np.rint(state[1:]).astype(np.int)
+        self.assertEqual(ending_image_count, image_count)
+        # weights should not be modified yet
+        self.assertTrue((self.model_weights == weights).all())
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must not get called for the first buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertTrue(k['SD']['segment_model'])
+
+        # Non-last iteration Call
+        self.subject.K.set_session.reset_mock()
+        self.subject.K.clear_session.reset_mock()
         previous_state = np.array(self.model_weights, dtype=np.float32)
 
         k = {'SD' : {}}
@@ -101,42 +128,141 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.all_seg_ids, self.total_images_per_seg, 0, 4,
-            previous_state.tostring(), **k)
+            previous_state.tostring(), False, **k)
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
         self.assertEqual(ending_image_count, image_count)
         # weights should not be modified yet
         self.assertTrue((self.model_weights == weights).all())
-        # set_session must get called ONLY once, when its the first buffer
+        # set_session is always called
         self.assertEqual(1, self.subject.K.set_session.call_count)
         # Clear session and sess.close must not get called for the first buffer
-        self.assertEqual(0, self.subject.clear_keras_session.call_count)
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertTrue(k['SD']['segment_model'])
+
+    def test_fit_transition_multiple_model_first_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+        starting_image_count = 0
+        ending_image_count = len(self.dependent_var_int)
+
+        previous_state = [starting_image_count]
+        previous_state.extend(self.model_weights)
+        previous_state = np.array(previous_state, dtype=np.float32)
+
+        k = {'SD': {}}
+
+        new_state = self.subject.fit_transition(
+            None, self.dependent_var, self.independent_var ,
+            self.dependent_var_shape, self.independent_var_shape, self.model.to_json(),
+            self.compile_params, self.fit_params, 0, self.all_seg_ids,
+            self.total_images_per_seg, 0, 4, previous_state.tostring(), True,
+            True, **k)
+        state = np.fromstring(new_state, dtype=np.float32)
+        image_count = state[0]
+        weights = np.rint(state[1:]).astype(np.int)
+        self.assertEqual(ending_image_count, image_count)
+        # weights should not be modified yet
+        self.assertTrue((self.model_weights == weights).all())
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session must not be called for the first buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
     def _test_fit_transition_middle_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `clear_keras_session`
+        # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
+        self.subject.K.clear_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
+        # last iteration Call
+
+        state = [starting_image_count]
+        state.extend(self.model_weights)
+        state = np.array(state, dtype=np.float32)
+
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                             '/cpu:0', self.serialized_weights)
+        k = {'SD': {'segment_model': self.model, 'sess': Mock()}}
+
+        new_state = self.subject.fit_transition(
+            state.tostring(), self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
+            self.total_images_per_seg, 0, 4, 'dummy_previous_state', True, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        image_count = state[0]
+        weights = np.rint(state[1:]).astype(np.int)
+        self.assertEqual(ending_image_count, image_count)
+        # weights should not be modified yet
+        self.assertTrue((self.model_weights == weights).all())
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must not get called for the middle buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+
+        # Non-last iteration Call
+
+        self.subject.K.set_session.reset_mock()
+        self.subject.K.clear_session.reset_mock()
+        state = [starting_image_count]
+        state.extend(self.model_weights)
+        state = np.array(state, dtype=np.float32)
+
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                             '/cpu:0', self.serialized_weights)
+        k = {'SD': {'segment_model': self.model, 'sess': Mock()}}
+
+        new_state = self.subject.fit_transition(
+            state.tostring(), self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
+            self.total_images_per_seg, 0, 4, 'dummy_previous_state', False, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        image_count = state[0]
+        weights = np.rint(state[1:]).astype(np.int)
+        self.assertEqual(ending_image_count, image_count)
+        # weights should not be modified yet
+        self.assertTrue((self.model_weights == weights).all())
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must not get called for the middle buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+
+    def test_fit_transition_multiple_model_middle_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        # last iteration Call
+
         state = [starting_image_count]
         state.extend(self.model_weights)
         state = np.array(state, dtype=np.float32)
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
                                              '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model': self.model}}
+        k = {'SD': {'segment_model': self.model, 'sess': Mock()}}
 
         new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, 'dummy_previous_state', **k)
+            self.total_images_per_seg, 0, 4, 'dummy_previous_state', True, True, **k)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -144,21 +270,22 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(ending_image_count, image_count)
         # weights should not be modified yet
         self.assertTrue((self.model_weights == weights).all())
-        # set_session must get called ONLY once, when its the first buffer
-        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
         # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.clear_keras_session.call_count)
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
 
     def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `clear_keras_session`
+        # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
+        self.subject.K.clear_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
+        # last iteration Call
         state = [starting_image_count]
         state.extend(self.model_weights)
         state = np.array(state, dtype=np.float32)
@@ -167,12 +294,12 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
                                              '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model}}
+        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
         new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, 'dummy_previous_state', **k)
+            self.total_images_per_seg, 0, 4, 'dummy_previous_state', True, **k)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -180,11 +307,75 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(ending_image_count, image_count)
         # weights should be multiplied by final image count
         self.assertTrue((multiplied_weights == weights).all())
-        # set_session must be not be called in transition func for PG
-        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
         # Clear session and sess.close must get called for the last buffer in gpdb,
         #  but not in postgres
-        self.assertEqual(1, self.subject.clear_keras_session.call_count)
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+
+        # Non-last iteration Call
+        self.subject.K.set_session.reset_mock()
+        self.subject.K.clear_session.reset_mock()
+        state = [starting_image_count]
+        state.extend(self.model_weights)
+        state = np.array(state, dtype=np.float32)
+
+        multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
+
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                             '/cpu:0', self.serialized_weights)
+        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+        new_state = self.subject.fit_transition(
+            state.tostring(), self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
+            self.total_images_per_seg, 0, 4, 'dummy_previous_state', False, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        image_count = state[0]
+        weights = np.rint(state[1:]).astype(np.int)
+        self.assertEqual(ending_image_count, image_count)
+        # weights should be multiplied by final image count
+        self.assertTrue((multiplied_weights == weights).all())
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must not get called for the last buffer
+        #  when this is not the final iteration
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+
+    def test_fit_transition_multiple_model_last_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        # last iteration Call
+        state = [starting_image_count]
+        state.extend(self.model_weights)
+        state = np.array(state, dtype=np.float32)
+
+
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                             '/cpu:0', self.serialized_weights)
+        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+        new_state = self.subject.fit_transition(
+            state.tostring(), self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape, self.model.to_json(),
+            None, self.fit_params, 0, self.all_seg_ids, self.total_images_per_seg,
+            0, 4, 'dummy_previous_state', True, True, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        image_count = state[0]
+        weights = np.rint(state[1:]).astype(np.int)
+        self.assertEqual(ending_image_count, image_count)
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must get called for the last buffer in gpdb,
+        #  but not in postgres
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
 
     def test_fit_transition_first_buffer_pass_pg(self):
         self._test_fit_transition_first_buffer_pass(True)
@@ -350,10 +541,8 @@ class InternalKerasPredictTestCase(unittest.TestCase):
 
     def test_predict_first_image_pass_gpdb(self):
         self.subject.is_platform_pg = Mock(return_value = False)
-        model_weights = [1,2,3,4,5,6]
-        serialized_weights = [0, 0, 0] # not used
-        serialized_weights.extend(model_weights)
-        serialized_weights = np.array(serialized_weights, dtype=np.float32).tostring()
+        model_weights = [1, 2, 3, 4]
+        serialized_weights = np.array(model_weights, dtype=np.float32).tostring()
 
         k = {'SD': {}}
         is_response = True
@@ -830,7 +1019,7 @@ class MadlibKerasWrapperTestCase(unittest.TestCase):
         self.assertIn('invalid optimizer', str(error.exception))
 
 
-class MadlibKerasFitInputValidatorTestCase(unittest.TestCase):
+class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
         patches = {
@@ -850,31 +1039,31 @@ class MadlibKerasFitInputValidatorTestCase(unittest.TestCase):
 
 
     def test_is_valid_metrics_compute_frequency_True_None(self):
-        self.subject.FitInputValidator._validate_input_args = Mock()
-        obj = self.subject.FitInputValidator(
+        self.subject.FitCommonValidator._validate_common_args = Mock()
+        obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', 5, None, False)
+            'dep_varname', 'independent_varname', 5, None, False, 'module_name')
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_True_num(self):
-        self.subject.FitInputValidator._validate_input_args = Mock()
-        obj = self.subject.FitInputValidator(
+        self.subject.FitCommonValidator._validate_common_args = Mock()
+        obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', 5, 3, False)
+            'dep_varname', 'independent_varname', 5, 3, False, 'module_name')
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_zero(self):
-        self.subject.FitInputValidator._validate_input_args = Mock()
-        obj = self.subject.FitInputValidator(
+        self.subject.FitCommonValidator._validate_common_args = Mock()
+        obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', 5, 0, False)
+            'dep_varname', 'independent_varname', 5, 0, False, 'module_name')
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_greater(self):
-        self.subject.FitInputValidator._validate_input_args = Mock()
-        obj = self.subject.FitInputValidator(
+        self.subject.FitCommonValidator._validate_common_args = Mock()
+        obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', 5, 6, False)
+            'dep_varname', 'independent_varname', 5, 6, False, 'module_name')
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
 
@@ -1017,17 +1206,25 @@ class MadlibSerializerTestCase(unittest.TestCase):
         invalid_model_weights = np.array([1,2], dtype=np.float32)
         dummy_model_shape = [(2, 1, 1, 1), (1,)]
 
-        # we expect keras failure(ValueError) because we cannot reshape
+        # we expect an exception to be raised because we cannot reshape
         # model weights of size 0 into shape (2,2,3,1)
-        with self.assertRaises(ValueError):
+        with self.assertRaises(plpy.PLPYException) as error:
             self.subject.deserialize_as_nd_weights(invalid_model_weights.tostring(),
                                                    dummy_model_shape)
 
         invalid_model_weights = np.array([1,2,3,4], dtype=np.float32)
         dummy_model_shape = [(2, 2, 3, 1), (1,)]
-        # we expect keras failure(ValueError) because we cannot reshape
+        # we expect an exception to be raised because we cannot reshape
         # model weights of size 2 into shape (2,2,3,1)
-        with self.assertRaises(ValueError):
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject.deserialize_as_nd_weights(invalid_model_weights.tostring(),
+                                                   dummy_model_shape)
+
+        invalid_model_weights = np.array([0,1,2,3,4], dtype=np.float32)
+        dummy_model_shape = [(2, 1), (1,)]
+        # we expect an exception to be raised because the 5 serialized weights
+        # exceed the 3 elements required by shapes (2, 1) and (1,)
+        with self.assertRaises(plpy.PLPYException) as error:
             self.subject.deserialize_as_nd_weights(invalid_model_weights.tostring(),
                                                    dummy_model_shape)
 
@@ -1141,20 +1338,21 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
 
     def _test_internal_keras_eval_transition_first_buffer(self, is_platform_pg):
         self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
+        self.subject.K.clear_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
         starting_image_count = 0
         ending_image_count = len(self.dependent_var_int)
 
+        # last iteration call
+
         k = {'SD' : {}}
         state = [0,0,0]
-
         new_state = self.subject.internal_keras_eval_transition(
             state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(),
             self.serialized_weights, self.compile_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 3, **k)
+            self.total_images_per_seg, 0, 3, True, **k)
 
         agg_loss, agg_accuracy, image_count = new_state
 
@@ -1165,19 +1363,44 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
         self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
         # Clear session and sess.close must not get called for the first buffer
-        self.assertEqual(0, self.subject.clear_keras_session.call_count)
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertTrue(k['SD']['segment_model'])
+
+        # Non-final call
+
+        self.subject.K.set_session.reset_mock()
+        self.subject.K.clear_session.reset_mock()
+        k = {'SD' : {}}
+        state = [0,0,0]
+        new_state = self.subject.internal_keras_eval_transition(
+            state, self.dependent_var , self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.serialized_weights, self.compile_params,
+            0, self.all_seg_ids, self.total_images_per_seg, 0, 3, False, **k)
+        agg_loss, agg_accuracy, image_count = new_state
+
+        self.assertEqual(ending_image_count, image_count)
+        # set_session must get called exactly once for the first buffer
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # loss and accuracy should be unchanged
+        self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
+        self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
+        # Clear session and sess.close must not get called for the first buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
     def _test_internal_keras_eval_transition_middle_buffer(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `clear_keras_session`
+        # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
+        self.subject.K.clear_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
+        # last iteration call
+
         k = {'SD' : {}}
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
@@ -1185,34 +1408,67 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
 
         state = [self.loss * starting_image_count, self.accuracy * starting_image_count, starting_image_count]
         k['SD']['segment_model'] = self.model
+        k['SD']['sess'] = Mock()
 
         new_state = self.subject.internal_keras_eval_transition(
             state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(),
             'dummy_model_data', None, 0,self.all_seg_ids,
-            self.total_images_per_seg, 0, 3, **k)
+            self.total_images_per_seg, 0, 3, True, **k)
 
         agg_loss, agg_accuracy, image_count = new_state
 
         self.assertEqual(ending_image_count, image_count)
-        # set_session is only called in first buffer, not here
-        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+         # loss and accuracy should be unchanged
+        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
+        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
+        # Clear session and sess.close must not get called for the middle buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+
+        # Non-last iteration call
+
+        self.subject.K.set_session.reset_mock()
+        self.subject.K.clear_session.reset_mock()
+        k = {'SD' : {}}
+
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                             '/cpu:0', self.serialized_weights)
+
+        state = [self.loss * starting_image_count, self.accuracy * starting_image_count, starting_image_count]
+        k['SD']['segment_model'] = self.model
+        k['SD']['sess'] = Mock()
+
+        new_state = self.subject.internal_keras_eval_transition(
+            state, self.dependent_var , self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(),
+            'dummy_model_data', None, 0,self.all_seg_ids,
+            self.total_images_per_seg, 0, 3, False, **k)
+
+        agg_loss, agg_accuracy, image_count = new_state
+
+        self.assertEqual(ending_image_count, image_count)
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
          # loss and accuracy should be unchanged
         self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
         self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
         # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.clear_keras_session.call_count)
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
 
     def _test_internal_keras_eval_transition_last_buffer(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `clear_keras_session`
+        # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
+        self.subject.K.clear_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
+
         k = {'SD' : {}}
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
@@ -1223,24 +1479,60 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
                  starting_image_count]
 
         k['SD']['segment_model'] = self.model
+        k['SD']['sess'] = Mock()
+
         new_state = self.subject.internal_keras_eval_transition(
             state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(),
             'dummy_model_data', None, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 3, **k)
+            self.total_images_per_seg, 0, 3, True, **k)
 
         agg_loss, agg_accuracy, image_count = new_state
 
         self.assertEqual(ending_image_count, image_count)
-        # set_session is only called in first buffer, not here
-        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
         # loss and accuracy should be unchanged
         self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
         self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
         # Clear session and sess.close must get called for the last buffer in gpdb,
         #  but not in postgres
-        self.assertEqual(1, self.subject.clear_keras_session.call_count)
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+
+        # Non-final call
+
+        self.subject.K.set_session.reset_mock()
+        self.subject.K.clear_session.reset_mock()
+        k = {'SD' : {}}
+
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                             '/cpu:0', self.serialized_weights)
+
+        state = [self.loss * starting_image_count,
+                 self.accuracy * starting_image_count,
+                 starting_image_count]
+
+        k['SD']['segment_model'] = self.model
+        k['SD']['sess'] = Mock()
+
+        new_state = self.subject.internal_keras_eval_transition(
+            state, self.dependent_var , self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(),
+            'dummy_model_data', None, 0, self.all_seg_ids,
+            self.total_images_per_seg, 0, 3, False, **k)
+
+        agg_loss, agg_accuracy, image_count = new_state
+
+        self.assertEqual(ending_image_count, image_count)
+        # set_session is always called
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # loss and accuracy should be unchanged
+        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
+        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
+        # Clear session and sess.close must not get called in non-final iterations
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
 
     def test_internal_keras_eval_transition_first_buffer_pg(self):
         self._test_internal_keras_eval_transition_first_buffer(True)
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 8f5b2ff..687566a 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -1112,3 +1112,17 @@ def create_table_drop_cols(source_table, out_table, cols_to_drop, **kwargs):
                    out_table=out_table,
                    source_table=source_table))
 # ------------------------------------------------------------------------------
+
+
+def rotate(l, n):
+    """Rotate a list to the right.
+    Rotate the list l to the right (the index increasing direction) by n elements.
+    Args:
+        l (list): The input list to rotate
+        n (integer): The number of elements to rotate
+
+    Returns:
+        list: The rotated list
+    """
+    return l[-n:] + l[:-n]
+# ------------------------------------------------------------------------------