Posted to commits@madlib.apache.org by ok...@apache.org on 2019/12/09 21:49:46 UTC

[madlib] branch master updated: DL: Add asymmetric cluster support for fit and evaluate

This is an automated email from the ASF dual-hosted git repository.

okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d2558f  DL: Add asymmetric cluster support for fit and evaluate
9d2558f is described below

commit 9d2558fb25de689fde7a8584b2260b894d4ae06c
Author: Orhan Kislal <ok...@apache.org>
AuthorDate: Fri Dec 6 19:57:58 2019 -0500

    DL: Add asymmetric cluster support for fit and evaluate
    
    JIRA: MADLIB-1393
    
    This commit updates the fit and evaluate transition functions to support
    asymmetric clusters.
    
    The gpus_per_host parameter is replaced with use_gpus, a boolean flag
    that indicates whether Keras operations run on CPUs or GPUs. The GPU
    configuration to use is read from the preprocessor summary table (see
    the illustrative call sketched below).
    
    Closes #462
    
    Co-authored-by: Ekta Khanna <ek...@pivotal.io>
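
For reference, a minimal sketch of the user-facing change. The table names and
compile/fit parameter values below are hypothetical; only the position and type
of the GPU argument come from the new signature.

    -- Before this commit (gpus_per_host INTEGER), e.g. one GPU per segment host:
    -- SELECT madlib.madlib_keras_fit('train_data_packed', 'model_out', 'model_arch_table', 1,
    --        $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,
    --        $$ batch_size=32, epochs=1 $$, 10, 1);
    -- After this commit (use_gpus BOOLEAN); the GPU configuration itself comes
    -- from the preprocessor summary table:
    SELECT madlib.madlib_keras_fit('train_data_packed', 'model_out', 'model_arch_table', 1,
           $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,
           $$ batch_size=32, epochs=1 $$, 10, TRUE);
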
---
 .../deep_learning/input_data_preprocessor.py_in    |   5 +-
 .../modules/deep_learning/madlib_keras.py_in       | 146 ++++++++-------
 .../modules/deep_learning/madlib_keras.sql_in      | 172 +++++++++---------
 .../madlib_keras_fit_multiple_model.py_in          |  78 +++++----
 .../madlib_keras_fit_multiple_model.sql_in         |  58 +++---
 .../deep_learning/madlib_keras_helper.py_in        |  61 ++++++-
 .../deep_learning/madlib_keras_predict.py_in       |   4 +-
 .../deep_learning/madlib_keras_validator.py_in     |  41 ++++-
 .../deep_learning/madlib_keras_wrapper.py_in       |  36 +++-
 .../test/madlib_keras_evaluate.sql_in              |   8 +-
 .../deep_learning/test/madlib_keras_fit.sql_in     |   4 +-
 .../test/madlib_keras_model_averaging_e2e.sql_in   |   8 +-
 .../test/madlib_keras_model_selection.sql_in       |   8 +-
 .../test/madlib_keras_model_selection_e2e.sql_in   |   4 +-
 .../deep_learning/test/madlib_keras_predict.sql_in |   2 +-
 .../test/madlib_keras_transfer_learning.sql_in     |  12 +-
 .../test/unit_tests/test_madlib_keras.py_in        | 195 +++++++++++++++------
 17 files changed, 539 insertions(+), 303 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
index 1acd136..757a5bc 100644
--- a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
+++ b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
@@ -292,7 +292,7 @@ class InputDataPreprocessorDL(object):
                 where_clause = "WHERE gp_segment_id=ANY({self.gpu_config})".format(**locals())
                 join_clause = gpu_join_clause.format(**locals())
 
-            elif self.distribution_rules == 'all_segments':
+            elif self.distribution_rules == DEFAULT_GPU_CONFIG:
 
                 self.distribution_rules = '$__madlib__$all_segments$__madlib__$'
                 where_clause = ''
@@ -384,7 +384,7 @@ class InputDataPreprocessorDL(object):
                 {self.normalizing_const}::{FLOAT32_SQL_TYPE} AS {normalizing_const_colname},
                 {self.num_classes} AS {num_classes_colname},
                 {self.distribution_rules} AS distribution_rules,
-                {self.gpu_config} AS __internal_gpu_config__
+                {self.gpu_config} AS {internal_gpu_config}
             """.format(self=self, class_level_str=class_level_str,
                        dependent_varname_colname=DEPENDENT_VARNAME_COLNAME,
                        independent_varname_colname=INDEPENDENT_VARNAME_COLNAME,
@@ -392,6 +392,7 @@ class InputDataPreprocessorDL(object):
                        class_values_colname=CLASS_VALUES_COLNAME,
                        normalizing_const_colname=NORMALIZING_CONST_COLNAME,
                        num_classes_colname=NUM_CLASSES_COLNAME,
+                       internal_gpu_config=INTERNAL_GPU_CONFIG,
                        FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE)
         plpy.execute(query)
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 5e35b46..9ed02d1 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -34,8 +34,10 @@ from madlib_keras_wrapper import *
 from model_arch_info import *
 
 from utilities.utilities import _assert
+from utilities.utilities import add_postfix
 from utilities.utilities import is_platform_pg
 from utilities.utilities import get_segments_per_host
+from utilities.utilities import get_seg_number
 from utilities.utilities import madlib_version
 from utilities.utilities import unique_string
 from utilities.validate_args import get_expr_type
@@ -56,7 +58,7 @@ class SD_STORE:
         del SD[SD_STORE.SEGMENT_MODEL]
         del SD[SD_STORE.SESS]
 
-def get_init_model_and_sess(SD, device_name, gpus_per_host, segments_per_host,
+def get_init_model_and_sess(SD, device_name, gpu_count, segments_per_host,
                                model_architecture, compile_params):
     # If a live session is present, re-use it. Otherwise, recreate it.
     if SD_STORE.SESS in SD :
@@ -67,7 +69,7 @@ def get_init_model_and_sess(SD, device_name, gpus_per_host, segments_per_host,
         segment_model = SD[SD_STORE.SEGMENT_MODEL]
         K.set_session(sess)
     else:
-        sess = get_keras_session(device_name, gpus_per_host, segments_per_host)
+        sess = get_keras_session(device_name, gpu_count, segments_per_host)
         K.set_session(sess)
         segment_model = init_model(model_architecture, compile_params)
         SD_STORE.init_SD(SD, sess, segment_model)
@@ -76,9 +78,11 @@ def get_init_model_and_sess(SD, device_name, gpus_per_host, segments_per_host,
 @MinWarning("warning")
 def fit(schema_madlib, source_table, model, model_arch_table,
         model_id, compile_params, fit_params, num_iterations,
-        gpus_per_host=0, validation_table=None,
+        use_gpus, validation_table=None,
         metrics_compute_frequency=None, warm_start=False, name="",
         description="", **kwargs):
+
+    module_name = 'madlib_keras_fit'
     fit_params = "" if not fit_params else fit_params
     _assert(compile_params, "Compile parameters cannot be empty or NULL.")
 
@@ -90,20 +94,27 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     ind_shape_col = add_postfix(
         MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, "_shape")
 
+    segments_per_host = get_segments_per_host()
+    use_gpus = use_gpus if use_gpus else False
+    if use_gpus:
+        accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib, segments_per_host, module_name)
+    else:
+        accessible_gpus_for_seg = get_seg_number()*[0]
+
     fit_validator = FitInputValidator(
         source_table, validation_table, model, model_arch_table,
         model_id, mb_dep_var_col, mb_indep_var_col,
-        num_iterations, metrics_compute_frequency, warm_start)
+        num_iterations, metrics_compute_frequency, warm_start,
+        use_gpus, accessible_gpus_for_seg)
     if metrics_compute_frequency is None:
         metrics_compute_frequency = num_iterations
 
-    # The following two times must be recorded together.
-    metrics_elapsed_start_time = time.time()
-    start_training_time = datetime.datetime.now()
 
-    segments_per_host, gpus_per_host = get_segments_and_gpus(gpus_per_host)
     warm_start = bool(warm_start)
 
+    # The following two times must be recorded together.
+    metrics_elapsed_start_time = time.time()
+    start_training_time = datetime.datetime.now()
     #TODO add a unit test for this in a future PR
     # save the original value of the env variable so that we can reset it later.
     original_cuda_env = None
@@ -117,11 +128,13 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     input_shape = get_input_shape(model_arch)
     fit_validator.validate_input_shapes(input_shape)
     dist_key_col = '0' if is_platform_pg() else DISTRIBUTION_KEY_COLNAME
+    gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
 
     serialized_weights = get_initial_weights(model, model_arch, model_weights,
-                                             warm_start, gpus_per_host)
+                                             warm_start, use_gpus, accessible_gpus_for_seg)
     # Compute total images on each segment
-    seg_ids_train, images_per_seg_train = get_image_count_per_seg_for_minibatched_data_from_db(source_table)
+    dist_key_mapping, images_per_seg_train = get_image_count_per_seg_for_minibatched_data_from_db(source_table)
+
 
     if validation_table:
         seg_ids_val, images_per_seg_val = get_image_count_per_seg_for_minibatched_data_from_db(validation_table)
@@ -143,10 +156,12 @@ def fit(schema_madlib, source_table, model, model_arch_table,
             {compile_params_to_pass}::TEXT,
             {fit_params_to_pass}::TEXT,
             {dist_key_col},
-            ARRAY{seg_ids_train},
-            ARRAY{images_per_seg_train},
-            {gpus_per_host},
+            ARRAY{dist_key_mapping},
+            {gp_segment_id_col},
             {segments_per_host},
+            ARRAY{images_per_seg_train},
+            {use_gpus}::BOOLEAN,
+            ARRAY{accessible_gpus_for_seg},
             $1,
             $2
         ) AS iteration_result
@@ -165,7 +180,8 @@ def fit(schema_madlib, source_table, model, model_arch_table,
         start_iteration = time.time()
         is_final_iteration = (i == num_iterations)
         serialized_weights = plpy.execute(run_training_iteration,
-                                        [serialized_weights, is_final_iteration])[0]['iteration_result']
+                                        [serialized_weights, is_final_iteration]
+                                        )[0]['iteration_result']
         end_iteration = time.time()
         info_str = "\tTime for training in iteration {0}: {1} sec".format(i,
             end_iteration - start_iteration)
@@ -175,7 +191,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
             # Compute loss/accuracy for training data.
             compute_out = compute_loss_and_metrics(
                 schema_madlib, source_table, compile_params_to_pass, model_arch,
-                serialized_weights, gpus_per_host, segments_per_host, seg_ids_train,
+                serialized_weights, use_gpus, accessible_gpus_for_seg, dist_key_mapping,
                 images_per_seg_train, training_metrics, training_loss, i, is_final_iteration)
             metrics_iters.append(i)
             compute_time, compute_metrics, compute_loss = compute_out
@@ -191,7 +207,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
                 # Compute loss/accuracy for validation data.
                 val_compute_out = compute_loss_and_metrics(
                     schema_madlib, validation_table, compile_params_to_pass,
-                    model_arch, serialized_weights, gpus_per_host, segments_per_host,
+                    model_arch, serialized_weights, use_gpus, accessible_gpus_for_seg,
                     seg_ids_val, images_per_seg_val, validation_metrics,
                     validation_loss, i, is_final_iteration)
                 val_compute_time, val_compute_metrics, val_compute_loss = val_compute_out
@@ -307,14 +323,14 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     reset_cuda_env(original_cuda_env)
 
 def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
-                        gpus_per_host):
+                        use_gpus, accessible_gpus_for_seg):
     """
         If warm_start is True, return back initial weights from model table.
         If warm_start is False, first try to get the weights from model_arch
         table, if no weights are defined there, randomly initialize it using
         keras.
         We also need to set the cuda environment variable based on the platform.
-        1. For postgres, if user specifies gpus_per_host=0 which means they want
+        1. For postgres, if user specifies use_gpus=False which means they want
         to use CPU, then we have to set CUDA_VISIBLE_DEVICES to -1 to disable gpu.
         Otherwise model.get_weights() will use gpu if available.
 
@@ -326,9 +342,9 @@ def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
             @param warm_start: Boolean flag indicating warm start or not.
     """
     if is_platform_pg():
-        _ = get_device_name_and_set_cuda_env(gpus_per_host, None)
+        _ = get_device_name_and_set_cuda_env(use_gpus, accessible_gpus_for_seg[0], None)
     else:
-        _ = get_device_name_and_set_cuda_env(0, None)
+        _ = get_device_name_and_set_cuda_env(False, 0, None)
 
     if warm_start:
         serialized_weights = plpy.execute("""
@@ -374,8 +390,8 @@ def get_metrics_sql_string(metrics_list, is_metrics_specified):
     return metrics_final, metrics_all
 
 def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
-                             serialized_weights, gpus_per_host, segments_per_host,
-                             seg_ids, images_per_seg_val, metrics_list, loss_list,
+                             serialized_weights, use_gpus, accessible_gpus_for_seg,
+                             dist_key_mapping, images_per_seg_val, metrics_list, loss_list,
                              curr_iter, is_final_iteration):
     """
     Compute the loss and metric using a given model (serialized_weights) on the
@@ -387,9 +403,9 @@ def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
                                                    compile_params,
                                                    model_arch,
                                                    serialized_weights,
-                                                   gpus_per_host,
-                                                   segments_per_host,
-                                                   seg_ids,
+                                                   use_gpus,
+                                                   accessible_gpus_for_seg,
+                                                   dist_key_mapping,
                                                    images_per_seg_val,
                                                    is_final_iteration)
     end_val = time.time()
@@ -439,9 +455,9 @@ def update_model(segment_model, prev_serialized_weights):
 
 def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                    independent_var_shape, model_architecture,
-                   compile_params, fit_params, current_seg_id, seg_ids,
-                   images_per_seg, gpus_per_host, segments_per_host,
-                   prev_serialized_weights, is_final_iteration=True,
+                   compile_params, fit_params, dist_key, dist_key_mapping,
+                   current_seg_id, segments_per_host, images_per_seg, use_gpus,
+                   accessible_gpus_for_seg, prev_serialized_weights, is_final_iteration=True,
                    is_multiple_model=False, **kwargs):
     """
     This transition function is common for madlib_keras_fit() and
@@ -462,11 +478,11 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     if not independent_var or not dependent_var:
         return state
     SD = kwargs['SD']
-    device_name = get_device_name_and_set_cuda_env(gpus_per_host,
-                                                   current_seg_id)
+    device_name = get_device_name_and_set_cuda_env(use_gpus, accessible_gpus_for_seg[current_seg_id], current_seg_id)
 
     segment_model, sess = get_init_model_and_sess(SD, device_name,
-                                                  gpus_per_host, segments_per_host,
+                                                  accessible_gpus_for_seg[current_seg_id],
+                                                  segments_per_host,
                                                   model_architecture, compile_params)
     agg_image_count = madlib_keras_serializer.get_image_count_from_state(state)
     if not state:
@@ -485,9 +501,8 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     # Aggregating number of images, loss and accuracy
     agg_image_count += image_count
     updated_weights = segment_model.get_weights()
-    total_images = get_image_count_per_seg_from_array(current_seg_id, seg_ids,
+    total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
-
     if agg_image_count == total_images:
         # For madlib_keras_fit_multiple_model(), we don't need to update weights
         # with the total no of images as there is no merge function for it.
@@ -549,20 +564,8 @@ def fit_final_multiple_model(state, **kwargs):
 
     return madlib_keras_serializer.serialize_nd_weights(weights)
 
-def get_segments_and_gpus(gpus_per_host):
-    gpus_per_host = 0 if gpus_per_host is None else gpus_per_host
-    segments_per_host = get_segments_per_host()
-
-    if 0 < gpus_per_host < segments_per_host:
-        plpy.warning('The number of GPUs per segment host is less than the number of '
-                     'segments per segment host. When different segments share the same GPU, '
-                     'this may fail in some scenarios. The current recommended configuration '
-                     'is to have 1 GPU available per segment.')
-
-    return segments_per_host, gpus_per_host
-
 def evaluate(schema_madlib, model_table, test_table, output_table,
-             gpus_per_host, mst_key, **kwargs):
+             use_gpus, mst_key, **kwargs):
 
     module_name = 'madlib_keras_evaluate'
     is_mult_model = mst_key is not None
@@ -578,7 +581,12 @@ def evaluate(schema_madlib, model_table, test_table, output_table,
         model_summary_table = create_summary_view(module_name, model_table, mst_key)
 
     validate_evaluate(module_name, model_table, model_summary_table, test_table, test_summary_table, output_table, is_mult_model)
-    segments_per_host, gpus_per_host = get_segments_and_gpus(gpus_per_host)
+
+    segments_per_host = get_segments_per_host()
+    if use_gpus:
+        accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib, segments_per_host, module_name)
+    else:
+        accessible_gpus_for_seg = get_seg_number()*[0]
 
     model_weights_query = "SELECT model_weights, model_arch FROM {0} {1}".format(
         model_table, mult_where_clause)
@@ -597,13 +605,13 @@ def evaluate(schema_madlib, model_table, test_table, output_table,
     metrics_type = res['metrics_type']
     compile_params = "$madlib$" + res['compile_params'] + "$madlib$"
 
-    seg_ids, images_per_seg = get_image_count_per_seg_for_minibatched_data_from_db(test_table)
+    dist_key_mapping, images_per_seg = get_image_count_per_seg_for_minibatched_data_from_db(test_table)
 
     loss, metric = \
         get_loss_metric_from_keras_eval(
             schema_madlib, test_table, compile_params, model_arch,
-            model_weights, gpus_per_host, segments_per_host,
-            seg_ids, images_per_seg)
+            model_weights, use_gpus, accessible_gpus_for_seg, dist_key_mapping,
+            images_per_seg)
 
     if not metrics_type:
         metrics_type = None
@@ -641,11 +649,13 @@ def validate_evaluate(module_name, model_table, model_summary_table, test_table,
                                      MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL)
 
 def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
-                                    model_arch, serialized_weights, gpus_per_host,
-                                    segments_per_host, seg_ids, images_per_seg,
+                                    model_arch, serialized_weights, use_gpus,
+                                    accessible_gpus_for_seg, dist_key_mapping, images_per_seg,
                                     is_final_iteration=True):
 
     dist_key_col = '0' if is_platform_pg() else DISTRIBUTION_KEY_COLNAME
+    gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
+    segments_per_host = get_segments_per_host()
 
     mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
     mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
@@ -658,6 +668,7 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
     This function will call the internal keras evaluate function to get the loss
     and accuracy of each tuple which then gets averaged to get the final result.
     """
+    use_gpus = use_gpus if use_gpus else False
     evaluate_query = plpy.prepare("""
         select ({schema_madlib}.internal_keras_evaluate(
                                             {mb_dep_var_col},
@@ -668,14 +679,16 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
                                             $1,
                                             {compile_params},
                                             {dist_key_col},
-                                            ARRAY{seg_ids},
-                                            ARRAY{images_per_seg},
-                                            {gpus_per_host},
+                                            ARRAY{dist_key_mapping},
+                                            {gp_segment_id_col},
                                             {segments_per_host},
+                                            ARRAY{images_per_seg},
+                                            {use_gpus}::BOOLEAN,
+                                            ARRAY{accessible_gpus_for_seg},
                                             {is_final_iteration}
                                             )) as loss_metric
         from {table}
-    """.format(**locals()), ["bytea"])
+        """.format(**locals()), ["bytea"])
     res = plpy.execute(evaluate_query, [serialized_weights])
     loss_metric = res[0]['loss_metric']
     return loss_metric
@@ -684,11 +697,12 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
 def internal_keras_eval_transition(state, dependent_var, independent_var,
                                    dependent_var_shape, independent_var_shape,
                                    model_architecture, serialized_weights, compile_params,
-                                   current_seg_id, seg_ids, images_per_seg,
-                                   gpus_per_host, segments_per_host,
+                                   dist_key, dist_key_mapping, current_seg_id,
+                                   segments_per_host, images_per_seg,
+                                   use_gpus, accessible_gpus_for_seg,
                                    is_final_iteration, **kwargs):
     SD = kwargs['SD']
-    device_name = get_device_name_and_set_cuda_env(gpus_per_host, current_seg_id)
+    device_name = get_device_name_and_set_cuda_env(use_gpus, accessible_gpus_for_seg[current_seg_id], current_seg_id)
     agg_loss, agg_metric, agg_image_count = state
 
     # This transition function is common to evaluate as well as the fit functions
@@ -702,8 +716,10 @@ def internal_keras_eval_transition(state, dependent_var, independent_var,
     #   for the last buffer of last iteration
     #  if is_final_iteration is false, we can clear the
 
-    segment_model, sess = get_init_model_and_sess(SD, device_name, gpus_per_host,
-                                                  segments_per_host, model_architecture,
+    segment_model, sess = get_init_model_and_sess(SD, device_name,
+                                                  accessible_gpus_for_seg[current_seg_id],
+                                                  segments_per_host,
+                                                  model_architecture,
                                                   compile_params)
     if not agg_image_count:
         # These should already be 0, but just in case make sure
@@ -731,7 +747,7 @@ def internal_keras_eval_transition(state, dependent_var, independent_var,
     agg_loss += (image_count * loss)
     agg_metric += (image_count * metric)
 
-    total_images = get_image_count_per_seg_from_array(current_seg_id, seg_ids,
+    total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
 
     if agg_image_count == total_images and is_final_iteration:
@@ -823,8 +839,7 @@ For more details on function usage:
     fit_params,                 --  Parameters passed to the fit method
                                     of the Keras model class
     num_iterations,             --  Number of iterations to train.
-    gpus_per_host,              --  Number of GPUs per segment host to
-                                    be used for training
+    use_gpus,                   --  Flag to enable GPU support
     validation_table,           --  Name of the table containing
                                     the validation dataset
     metrics_compute_frequency,  --  Frequency to compute per-iteration
@@ -886,8 +901,7 @@ For more details on function usage:
     model_table,    --  Name of the table containing the model
     test_table,     --  Name of the table containing the evaluation dataset
     output_table,   --  Name of the output table
-    gpus_per_host,  --  Number of GPUs per segment host to
-                        be used for training
+    use_gpus,       --  Flag to enable GPU support
     mst_key         --  Identifier for the desired model out of multimodel
                         training output
     )
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index e081760..be5449a 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -97,7 +97,7 @@ madlib_keras_fit(
     compile_params,
     fit_params,
     num_iterations,
-    gpus_per_host,
+    use_gpus,
     validation_table,
     metrics_compute_frequency,
     warm_start,
@@ -170,22 +170,17 @@ madlib_keras_fit(
   <DD>INTEGER.  Number of iterations to train.
   </DD>
 
-  <DT>gpus_per_host (optional)</DT>
-  <DD>INTEGER, default: 0 (i.e., CPU).
-    Number of GPUs per segment host to be used
-    for training the neural network.
-    For example, if you specify 4 for this parameter
-    and your database cluster is set up to have 4
-    segments per segment host, it means that each
-    segment will have a dedicated GPU.
-    A value of 0 means that CPUs, not GPUs, will
-    be used for training.
+  <DT>use_gpus (optional)</DT>
+  <DD>BOOLEAN, default: FALSE (i.e., CPU).
+    Flag to enable GPU support for training the neural network.
+    The number of GPUs to use is determined by the parameters
+    passed to the preprocessor.
 
     @note
     We have seen some memory related issues when segments
     share GPU resources.
-    For example, if you specify 1 for this parameter
-    and your database cluster is set up to have 4
+    For example, if you provide 1 GPU and your
+    database cluster is set up to have 4
     segments per segment host, it means that all 4
     segments on a segment host will share the same
     GPU. The current recommended
@@ -462,7 +457,7 @@ madlib_keras_evaluate(
     model_table,
     test_table,
     output_table,
-    gpus_per_host,
+    use_gpus,
     mst_key
     )
 </pre>
@@ -503,28 +498,25 @@ madlib_keras_evaluate(
         <td>Type of metric used that was used in the training step.</td>
       </tr>
 
-  <DT>gpus_per_host (optional)</DT>
-  <DD>INTEGER, default: 0 (i.e., CPU).
-    Number of GPUs per segment host to be used
-    for training the neural network.
-    For example, if you specify 4 for this parameter
-    and your database cluster is set up to have 4
-    segments per segment host, it means that each
-    segment will have a dedicated GPU.
-    A value of 0 means that CPUs, not GPUs, will
-    be used for training.
+
+  <DT>use_gpus (optional)</DT>
+  <DD>BOOLEAN, default: FALSE (i.e., CPU).
+    Flag to enable GPU support for evaluating the neural network.
+    The number of GPUs to use is determined by the parameters
+    passed to the preprocessor.
 
     @note
     We have seen some memory related issues when segments
     share GPU resources.
-    For example, if you specify 1 for this parameter
-    and your database cluster is set up to have 4
+    For example, if you provide 1 GPU and your
+    database cluster is set up to have 4
     segments per segment host, it means that all 4
     segments on a segment host will share the same
     GPU. The current recommended
     configuration is 1 GPU per segment.
   </DD>
 
+
   <DT>mst_key (optional)</DT>
   <DD>INTEGER, default: NULL. To be filled.
   </DD>
@@ -724,18 +716,24 @@ madlib_keras_predict_byom(
 
   <DT>gpus_per_host (optional)</DT>
   <DD>INTEGER, default: 0 (i.e., CPU).
-    Number of GPUs per segment host to be used for training the neural network.
-    For example, if you specify 4 for this parameter and your database cluster
-    is set up to have 4 segments per segment host, it means that each segment
-    will have a dedicated GPU. A value of 0 means that CPUs, not GPUs, will
+    Number of GPUs per segment host to be used
+    for training the neural network.
+    For example, if you specify 4 for this parameter
+    and your database cluster is set up to have 4
+    segments per segment host, it means that each
+    segment will have a dedicated GPU.
+    A value of 0 means that CPUs, not GPUs, will
     be used for training.
 
     @note
-    We have seen some memory related issues when segments share GPU resources.
-    For example, if you specify 1 for this parameter and your database cluster
-    is set up to have 4 segments per segment host, it means that all 4 segments
-     on a segment host will share the same GPU. The current recommended
-     configuration is 1 GPU per segment.
+    We have seen some memory related issues when segments
+    share GPU resources.
+    For example, if you specify 1 for this parameter
+    and your database cluster is set up to have 4
+    segments per segment host, it means that all 4
+    segments on a segment host will share the same
+    GPU. The current recommended
+    configuration is 1 GPU per segment.
   </DD>
 
   <DT>class_values (optional)</DT>
@@ -1610,7 +1608,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     compile_params          VARCHAR,
     fit_params              VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
@@ -1631,7 +1629,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     compile_params          VARCHAR,
     fit_params              VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
@@ -1649,7 +1647,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     compile_params          VARCHAR,
     fit_params              VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN
@@ -1667,7 +1665,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     compile_params          VARCHAR,
     fit_params              VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER
 ) RETURNS VOID AS $$
@@ -1684,7 +1682,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     compile_params          VARCHAR,
     fit_params              VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR
 ) RETURNS VOID AS $$
     SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, NULL, NULL, NULL, NULL);
@@ -1699,7 +1697,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     compile_params          VARCHAR,
     fit_params              VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER
+    use_gpus                BOOLEAN
 ) RETURNS VOID AS $$
     SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, NULL, NULL, NULL, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
@@ -1714,26 +1712,28 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     fit_params              VARCHAR,
     num_iterations          INTEGER
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, 0, NULL, NULL, NULL, NULL, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, FALSE, NULL, NULL, NULL, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
-    state                      BYTEA,
-    dependent_var              BYTEA,
-    independent_var            BYTEA,
-    dependent_var_shape        INTEGER[],
-    independent_var_shape      INTEGER[],
-    model_architecture         TEXT,
-    compile_params             TEXT,
-    fit_params                 TEXT,
-    current_seg_id             INTEGER,
-    seg_ids                    INTEGER[],
-    images_per_seg             INTEGER[],
-    gpus_per_host              INTEGER,
-    segments_per_host          INTEGER,
-    prev_serialized_weights    BYTEA,
-    is_final_iteration         BOOLEAN
+    state                       BYTEA,
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    use_gpus                    BOOLEAN,
+    accessible_gpus_for_seg     INTEGER[],
+    prev_serialized_weights     BYTEA,
+    is_final_iteration          BOOLEAN
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.fit_transition(**globals())
@@ -1767,9 +1767,11 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step(
     TEXT,
     INTEGER,
     INTEGER[],
-    INTEGER[],
     INTEGER,
     INTEGER,
+    INTEGER[],
+    BOOLEAN,
+    INTEGER[],
     BYTEA,
     BOOLEAN);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
@@ -1780,11 +1782,13 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
     /* model_architecture */     TEXT,
     /* compile_params */         TEXT,
     /* fit_params */             TEXT,
+    /* dist_key */               INTEGER,
+    /* dist_key_mapping */       INTEGER[],
     /* current_seg_id */         INTEGER,
-    /* seg_ids*/                 INTEGER[],
-    /* images_per_seg*/          INTEGER[],
-    /* gpus_per_host  */         INTEGER,
-    /* segments_per_host  */     INTEGER,
+    /* segments_per_host */      INTEGER,
+    /* images_per_seg */         INTEGER[],
+    /* use_gpus  */              BOOLEAN,
+    /* accessible_gpus_for_seg */ INTEGER[],
     /* serialized_weights */     BYTEA,
     /* is_final_iteration */     BOOLEAN
 )(
@@ -1951,7 +1955,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_evaluate(
     model_table             VARCHAR,
     test_table              VARCHAR,
     output_table            VARCHAR,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     mst_key                 INTEGER
 ) RETURNS VOID AS $$
     PythonFunction(`deep_learning', `madlib_keras', `evaluate')
@@ -1962,7 +1966,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_evaluate(
     model_table             VARCHAR,
     test_table              VARCHAR,
     output_table            VARCHAR,
-    gpus_per_host           INTEGER
+    use_gpus                BOOLEAN
 ) RETURNS VOID AS $$
   SELECT MADLIB_SCHEMA.madlib_keras_evaluate($1, $2, $3, $4, NULL);
 $$ LANGUAGE sql VOLATILE
@@ -1986,11 +1990,13 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_eval_transition(
     model_architecture                 TEXT,
     serialized_weights                 BYTEA,
     compile_params                     TEXT,
+    dist_key                           INTEGER,
+    dist_key_mapping                   INTEGER[],
     current_seg_id                     INTEGER,
-    seg_ids                            INTEGER[],
-    images_per_seg                     INTEGER[],
-    gpus_per_host                      INTEGER,
     segments_per_host                  INTEGER,
+    images_per_seg                     INTEGER[],
+    use_gpus                           BOOLEAN,
+    accessible_gpus_for_seg            INTEGER[],
     is_final_iteration                 BOOLEAN
 ) RETURNS REAL[3] AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
@@ -2025,25 +2031,29 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.internal_keras_evaluate(
                                        TEXT,
                                        INTEGER,
                                        INTEGER[],
-                                       INTEGER[],
                                        INTEGER,
                                        INTEGER,
+                                       INTEGER[],
+                                       BOOLEAN,
+                                       INTEGER[],
                                        BOOLEAN);
 
 CREATE AGGREGATE MADLIB_SCHEMA.internal_keras_evaluate(
-    /* dependent_var */                BYTEA,
-    /* independent_var */              BYTEA,
-    /* dependent_var_shape */          INTEGER[],
-    /* independent_var_shape */        INTEGER[],
-    /* model_architecture */           TEXT,
-    /* model_weights */                BYTEA,
-    /* compile_params */               TEXT,
-    /* current_seg_id */               INTEGER,
-    /* seg_ids */                      INTEGER[],
-    /* images_per_seg*/                INTEGER[],
-    /* gpus_per_host */                INTEGER,
-    /* segments_per_host */            INTEGER,
-    /* is_final_iteration */           BOOLEAN
+    /* dependent_var */             BYTEA,
+    /* independent_var */           BYTEA,
+    /* dependent_var_shape */       INTEGER[],
+    /* independent_var_shape */     INTEGER[],
+    /* model_architecture */        TEXT,
+    /* model_weights */             BYTEA,
+    /* compile_params */            TEXT,
+    /* dist_key */                  INTEGER,
+    /* dist_key_mapping */          INTEGER[],
+    /* current_seg_id */            INTEGER,
+    /* segments_per_host */         INTEGER,
+    /* images_per_seg*/             INTEGER[],
+    /* use_gpus */                  BOOLEAN,
+    /* accessible_gpus_for_seg */   INTEGER[],
+    /* is_final_iteration */        BOOLEAN
 )(
     STYPE=REAL[3],
     INITCOND='{0,0,0}',
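
The user-facing evaluate call changes in the same way. A minimal sketch with
hypothetical table names; only the use_gpus position and type come from the new
signature.

    SELECT madlib.madlib_keras_evaluate(
        'model_out',          -- model table written by madlib_keras_fit
        'test_data_packed',   -- preprocessed evaluation data
        'eval_out',           -- output table
        FALSE);               -- use_gpus: evaluate on CPUs
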
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 883ed22..f889980 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -25,7 +25,6 @@ from keras.models import *
 from madlib_keras import compute_loss_and_metrics
 from madlib_keras import get_initial_weights
 from madlib_keras import get_model_arch_weights
-from madlib_keras import get_segments_and_gpus
 from madlib_keras import get_source_summary_table_dict
 from madlib_keras import should_compute_metrics_this_iter
 from madlib_keras_helper import *
@@ -40,6 +39,8 @@ from utilities.utilities import add_postfix
 from utilities.utilities import rotate
 from utilities.utilities import madlib_version
 from utilities.utilities import is_platform_pg
+from utilities.utilities import get_seg_number
+from utilities.utilities import get_segments_per_host
 import json
 from collections import defaultdict
 import random
@@ -74,7 +75,7 @@ Note that this function is disabled for Postgres.
 class FitMultipleModel():
     def __init__(self, schema_madlib, source_table, model_output_table,
                  model_selection_table, num_iterations,
-                 gpus_per_host=0, validation_table=None,
+                 use_gpus=False, validation_table=None,
                  metrics_compute_frequency=None, warm_start=False, name="",
                  description="", **kwargs):
         # set the random seed for visit order/scheduling
@@ -112,12 +113,21 @@ class FitMultipleModel():
         self.info_str = ""
         self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
         self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
+        self.use_gpus = use_gpus
+        self.segments_per_host = get_segments_per_host()
+        if self.use_gpus:
+            self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(self.schema_madlib,
+                self.segments_per_host, self.module_name)
+        else:
+            self.accessible_gpus_for_seg = get_seg_number()*[0]
+
         self.fit_validator_train = FitMultipleInputValidator(
             self.source_table, self.validation_table, self.model_output_table,
             self.model_selection_table, self.model_selection_summary_table,
             mb_dep_var_col, mb_indep_var_col, self.num_iterations,
             self.model_info_table, self.mst_key_col, self.model_arch_table_col,
-            self.metrics_compute_frequency, warm_start)
+            self.metrics_compute_frequency, warm_start, self.use_gpus,
+            self.accessible_gpus_for_seg)
         if self.metrics_compute_frequency is None:
             self.metrics_compute_frequency = num_iterations
         self.warm_start = bool(warm_start)
@@ -129,7 +139,7 @@ class FitMultipleModel():
         if CUDA_VISIBLE_DEVICES_KEY in os.environ:
             original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
-        self.seg_ids_train, self.images_per_seg_train = \
+        self.dist_key_mapping, self.images_per_seg_train = \
             get_image_count_per_seg_for_minibatched_data_from_db(
                 self.source_table)
 
@@ -137,7 +147,7 @@ class FitMultipleModel():
             self.valid_mst_metric_eval_time = defaultdict(list)
             self.valid_mst_loss = defaultdict(list)
             self.valid_mst_metric = defaultdict(list)
-            self.seg_ids_valid, self.images_per_seg_valid = \
+            self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
         self.mst_weights_tbl = unique_string(desp='mst_weights')
@@ -151,8 +161,7 @@ class FitMultipleModel():
             self.msts_for_schedule = self.msts
         random.shuffle(self.msts_for_schedule)
         self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.segments_per_host, self.gpus_per_host = get_segments_and_gpus(
-            gpus_per_host)
+        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
         if not self.warm_start:
             self.create_model_output_table()
         self.weights_to_update_tbl = unique_string(desp='weights_to_update')
@@ -197,13 +206,13 @@ class FitMultipleModel():
             mst_metric_eval_time = self.train_mst_metric_eval_time
             mst_loss = self.train_mst_loss
             mst_metric = self.train_mst_metric
-            seg_ids = self.seg_ids_train
+            seg_ids = self.dist_key_mapping
             images_per_seg = self.images_per_seg_train
         else:
             mst_metric_eval_time = self.valid_mst_metric_eval_time
             mst_loss = self.valid_mst_loss
             mst_metric = self.valid_mst_metric
-            seg_ids = self.seg_ids_valid
+            seg_ids = self.dist_key_mapping_valid
             images_per_seg = self.images_per_seg_valid
         for mst in self.msts:
             weights = query_weights(self.model_output_table, self.model_weights_col,
@@ -214,10 +223,11 @@ class FitMultipleModel():
                     mst[self.compile_params_col]),
                 model_arch,
                 weights,
-                self.gpus_per_host,
-                self.segments_per_host,
+                self.use_gpus,
+                self.accessible_gpus_for_seg,
                 seg_ids,
-                images_per_seg, [], [], epoch, True)
+                images_per_seg,
+                [], [], epoch, True)
             mst_metric_eval_time[mst[self.mst_key_col]] \
                 .append(metric_eval_time)
             mst_loss[mst[self.mst_key_col]].append(loss)
@@ -298,7 +308,9 @@ class FitMultipleModel():
                                                      model_arch,
                                                      model_weights,
                                                      False,
-                                                     self.gpus_per_host)
+                                                     self.use_gpus,
+                                                     self.accessible_gpus_for_seg
+                                                     )
             model = model_from_json(model_arch)
             serialized_state = model_weights if model_weights else \
                 madlib_keras_serializer.serialize_nd_weights(model.get_weights())
@@ -468,34 +480,38 @@ class FitMultipleModel():
         """.format(dist_key_col=dist_key_col,
                    **locals())
         plpy.execute(mst_weights_query)
+        use_gpus = self.use_gpus if self.use_gpus else False
         uda_query = """
             CREATE TABLE {self.weights_to_update_tbl} AS
             SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                                                      {mb_indep_var_col},
-                                                      {self.dep_shape_col},
-                                                      {self.ind_shape_col},
-                                                      {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                                                      {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                                                      {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                                                      src.{dist_key_col},
-                                                      ARRAY{self.seg_ids_train},
-                                                      ARRAY{self.images_per_seg_train},
-                                                      {self.gpus_per_host},
-                                                      {self.segments_per_host},
-                                                      {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                                                      {is_final_iteration}::BOOLEAN
-                                                      )::BYTEA AS {self.model_weights_col},
+                {mb_indep_var_col},
+                {self.dep_shape_col},
+                {self.ind_shape_col},
+                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
+                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
+                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
+                src.{dist_key_col},
+                ARRAY{self.dist_key_mapping},
+                src.{self.gp_segment_id_col},
+                {self.segments_per_host},
+                ARRAY{self.images_per_seg_train},
+                {use_gpus}::BOOLEAN,
+                ARRAY{self.accessible_gpus_for_seg},
+                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
+                {is_final_iteration}::BOOLEAN
+                )::BYTEA AS {self.model_weights_col},
                 {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
                 ,src.{dist_key_col} AS {dist_key_col}
-                FROM {self.source_table} src JOIN {self.mst_weights_tbl}
-                    USING ({dist_key_col})
-                WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL
-                GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
+            FROM {self.source_table} src JOIN {self.mst_weights_tbl}
+                USING ({dist_key_col})
+            WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL
+            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
             DISTRIBUTED BY({dist_key_col})
             """.format(mb_dep_var_col=mb_dep_var_col,
                        mb_indep_var_col=mb_indep_var_col,
                        is_final_iteration=True,
                        dist_key_col=dist_key_col,
+                       use_gpus=use_gpus,
                        self=self
                        )
         plpy.execute(uda_query)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 433535d..d1c261a 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -34,7 +34,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     model_output_table      VARCHAR,
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
@@ -52,7 +52,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     model_output_table      VARCHAR,
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
@@ -67,7 +67,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     model_output_table      VARCHAR,
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN
@@ -81,7 +81,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     model_output_table      VARCHAR,
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER
 ) RETURNS VOID AS $$
@@ -94,7 +94,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     model_output_table      VARCHAR,
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER,
+    use_gpus                BOOLEAN,
     validation_table        VARCHAR
 ) RETURNS VOID AS $$
     SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, $6, NULL, FALSE, NULL, NULL);
@@ -106,7 +106,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     model_output_table      VARCHAR,
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
-    gpus_per_host           INTEGER
+    use_gpus                BOOLEAN
 ) RETURNS VOID AS $$
     SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, NULL, NULL, FALSE, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
@@ -121,11 +121,13 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
     model_architecture         TEXT,
     compile_params             TEXT,
     fit_params                 TEXT,
+    dist_key                   INTEGER,
+    dist_key_mapping           INTEGER[],
     current_seg_id             INTEGER,
-    seg_ids                    INTEGER[],
-    images_per_seg             INTEGER[],
-    gpus_per_host              INTEGER,
     segments_per_host          INTEGER,
+    images_per_seg             INTEGER[],
+    use_gpus                   BOOLEAN,
+    accessible_gpus_for_seg    INTEGER[],
     prev_serialized_weights    BYTEA,
     is_final_iteration         BOOLEAN
 ) RETURNS BYTEA AS $$
@@ -146,33 +148,37 @@ m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step_multiple_model(
     BYTEA,
     BYTEA,
-    TEXT,
-    TEXT,
+    INTEGER[],
+    INTEGER[],
     TEXT,
     TEXT,
     TEXT,
     INTEGER,
     INTEGER[],
-    INTEGER[],
     INTEGER,
     INTEGER,
+    INTEGER[],
+    BOOLEAN,
+    INTEGER[],
     BYTEA,
     BOOLEAN);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
-    /* dep_var */                BYTEA,
-    /* ind_var */                BYTEA,
-    /* dep_var_shape */          INTEGER[],
-    /* ind_var_shape */          INTEGER[],
-    /* model_architecture */     TEXT,
-    /* compile_params */         TEXT,
-    /* fit_params */             TEXT,
-    /* current_seg_id */         INTEGER,
-    /* seg_ids*/                 INTEGER[],
-    /* images_per_seg*/          INTEGER[],
-    /* gpus_per_host  */         INTEGER,
-    /* segments_per_host  */     INTEGER,
-    /* serialized_weights */     BYTEA,
-    /* is_final_iteration */     BOOLEAN
+    /* dependent_var */              BYTEA,
+    /* independent_var */            BYTEA,
+    /* dependent_var_shape */        INTEGER[],
+    /* independent_var_shape */      INTEGER[],
+    /* model_architecture */         TEXT,
+    /* compile_params */             TEXT,
+    /* fit_params */                 TEXT,
+    /* dist_key */                   INTEGER,
+    /* dist_key_mapping */           INTEGER[],
+    /* current_seg_id */             INTEGER,
+    /* segments_per_host */          INTEGER,
+    /* images_per_seg */             INTEGER[],
+    /* use_gpus */                   BOOLEAN,
+    /* accessible_gpus_for_seg */    INTEGER[],
+    /* prev_serialized_weights */    BYTEA,
+    /* is_final_iteration */         BOOLEAN
 )(
     STYPE=BYTEA,
     SFUNC=MADLIB_SCHEMA.fit_transition_multiple_model,
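
madlib_keras_fit_multiple_model() takes the same flag in place of gpus_per_host.
A minimal sketch with hypothetical table names; only the use_gpus position and
type reflect the new signature.

    SELECT madlib.madlib_keras_fit_multiple_model(
        'train_data_packed',   -- preprocessed training data
        'mst_models_out',      -- model output table
        'mst_table',           -- model selection table
        10,                    -- num_iterations
        TRUE);                 -- use_gpus
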
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index 70761d5..3dcc572 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -23,6 +23,7 @@ from utilities.utilities import add_postfix
 from utilities.utilities import unique_string
 from utilities.utilities import is_platform_pg
 from utilities.validate_args import table_exists
+from madlib_keras_gpu_info import GPUInfoFunctions
 import plpy
 
 
@@ -51,6 +52,8 @@ SMALLINT_SQL_TYPE = 'SMALLINT'
 
 DEFAULT_NORMALIZING_CONST = 1.0
 DEFAULT_GPU_CONFIG = 'all_segments'
+GP_SEGMENT_ID_COLNAME = "gp_segment_id"
+INTERNAL_GPU_CONFIG = '__internal_gpu_config__'
 
 #####################################################################
 
@@ -105,7 +108,7 @@ def strip_trailing_nulls_from_class_values(class_values):
         class_values = class_values[:num_of_valid_class_values]
     return class_values
 
-def get_image_count_per_seg_from_array(current_seg_id, seg_ids, images_per_seg):
+def get_image_count_per_seg_from_array_predict(current_seg_id, seg_ids, images_per_seg):
     """
     Get the image count from the array containing all the images
     per segment. Based on the platform, we find the index of the current segment.
@@ -117,6 +120,18 @@ def get_image_count_per_seg_from_array(current_seg_id, seg_ids, images_per_seg):
         total_images = images_per_seg[seg_ids.index(current_seg_id)]
     return total_images
 
+def get_image_count_per_seg_from_array(current_seg_id, images_per_seg):
+    """
+    Get the image count from the array containing all the images
+    per segment. Based on the platform, we find the index of the current segment.
+    This function is only called from inside the transition function.
+    """
+    if is_platform_pg():
+        total_images = images_per_seg[0]
+    else:
+        total_images = images_per_seg[current_seg_id]
+    return total_images
+
 def get_image_count_per_seg_for_minibatched_data_from_db(table_name):
     """
     Query the given minibatch formatted table and return the total rows per segment.
@@ -266,3 +281,47 @@ def create_summary_view(module_name, model_table, mst_key):
         WHERE mst_key = {mst_key}
         """.format(**locals()))
     return tmp_view_summary
+
+def get_accessible_gpus_for_seg(schema_madlib, segments_per_host, module_name):
+
+    if is_platform_pg():
+        gpus = GPUInfoFunctions.get_gpu_info_from_tensorflow()
+        return [len(gpus)]
+    else:
+        gpu_table_name = unique_string(desp = 'gpu_table')
+        gpu_table_query = """
+            SELECT {schema_madlib}.gpu_configuration('{gpu_table_name}')
+        """.format(**locals())
+        plpy.execute(gpu_table_query)
+        gpu_query = """
+            SELECT hostname, count(*) AS count FROM {gpu_table_name} GROUP BY hostname
+            """.format(**locals())
+        gpu_query_result = plpy.execute(gpu_query)
+
+        if not gpu_query_result:
+           plpy.error("{0} error: No GPUs configured on hosts.".format(module_name))
+
+        host_dict = {}
+        for i in gpu_query_result:
+            host_dict[i['hostname']] = int(i['count'])
+
+        seg_query = """
+            SELECT hostname, content AS segment_id
+            FROM gp_segment_configuration
+            WHERE content != -1 AND role = 'p'
+        """
+        seg_query_result = plpy.execute(seg_query)
+
+        accessible_gpus_for_seg = [0] * len(seg_query_result)
+        warning_flag = True
+        for i in seg_query_result:
+            if i['hostname'] in host_dict.keys():
+                accessible_gpus_for_seg[i['segment_id']] = host_dict[i['hostname']]
+            if 0 < accessible_gpus_for_seg[i['segment_id']] < segments_per_host and warning_flag:
+                plpy.warning(
+                    'The number of GPUs per segment host is less than the number of '
+                    'segments per segment host. When different segments share the '
+                    'same GPU, this may fail in some scenarios. The current '
+                    'recommended configuration is to have 1 GPU available per segment.')
+                warning_flag = False
+        return accessible_gpus_for_seg
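
The per-segment GPU counts come from get_accessible_gpus_for_seg above, which
joins the gpu_configuration output (hostname, GPU count) against
gp_segment_configuration. Below is a minimal pure-Python sketch of that
mapping step, with the two plpy queries replaced by plain lists and the
shared-GPU warning omitted; the hostnames and segment ids are made up.

    # Pure-Python sketch of the host-to-segment mapping in
    # get_accessible_gpus_for_seg; the real function reads gpu_rows from the
    # table created by madlib.gpu_configuration() and seg_rows from
    # gp_segment_configuration, and warns when segments share a GPU.
    def map_gpus_to_segments(gpu_rows, seg_rows):
        host_gpu_count = {r['hostname']: int(r['count']) for r in gpu_rows}
        accessible_gpus_for_seg = [0] * len(seg_rows)
        for r in seg_rows:
            accessible_gpus_for_seg[r['segment_id']] = \
                host_gpu_count.get(r['hostname'], 0)
        return accessible_gpus_for_seg

    # Two hosts with two segments each, but only mdw0 has a GPU:
    print(map_gpus_to_segments(
        [{'hostname': 'mdw0', 'count': 1}],
        [{'hostname': 'mdw0', 'segment_id': 0},
         {'hostname': 'mdw0', 'segment_id': 1},
         {'hostname': 'mdw1', 'segment_id': 2},
         {'hostname': 'mdw1', 'segment_id': 3}]))   # -> [1, 1, 0, 0]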
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 8c756ee..cf06e45 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -250,7 +250,7 @@ def internal_keras_predict(independent_var, model_architecture, model_weights,
     model_key = 'segment_model_predict'
     row_count_key = 'row_count'
     try:
-        device_name = get_device_name_and_set_cuda_env(gpus_per_host,
+        device_name = get_device_name_and_set_cuda_env_predict(gpus_per_host,
                                                        current_seg_id)
         if model_key not in SD:
             set_keras_session(device_name, gpus_per_host, segments_per_host)
@@ -284,7 +284,7 @@ def internal_keras_predict(independent_var, model_architecture, model_weights,
             # and not mini-batched, this list contains exactly one list in it,
             # so return back the first list in probs.
             result = probs[0]
-        total_images = get_image_count_per_seg_from_array(current_seg_id, seg_ids,
+        total_images = get_image_count_per_seg_from_array_predict(current_seg_id, seg_ids,
                                                           images_per_seg)
 
         if SD[row_count_key] == total_images:
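
Predict keeps the old lookup keyed by an explicit seg_ids array (now renamed
with a _predict suffix), while fit and evaluate index images_per_seg directly
by segment id. A small sketch of the two lookups on the Greenplum path (the
Postgres path simply returns images_per_seg[0]); the sample values are made up.

    # Sketch contrasting the two image-count lookups defined in
    # madlib_keras_helper.py_in above (Greenplum path only).
    def count_for_predict(current_seg_id, seg_ids, images_per_seg):
        # predict passes an explicit seg_ids list, so find the position first
        return images_per_seg[seg_ids.index(current_seg_id)]

    def count_for_fit(current_seg_id, images_per_seg):
        # fit/evaluate build images_per_seg indexed by gp_segment_id directly
        return images_per_seg[current_seg_id]

    print(count_for_predict(2, [3, 1, 2], [40, 50, 60]))  # -> 60
    print(count_for_fit(2, [10, 20, 30]))                 # -> 30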
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 49b8934..ad14087 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -30,6 +30,8 @@ from madlib_keras_helper import MODEL_WEIGHTS_COLNAME
 from madlib_keras_helper import NORMALIZING_CONST_COLNAME
 from madlib_keras_helper import DISTRIBUTION_KEY_COLNAME
 from madlib_keras_helper import METRIC_TYPE_COLNAME
+from madlib_keras_helper import INTERNAL_GPU_CONFIG
+from madlib_keras_helper import DEFAULT_GPU_CONFIG
 from madlib_keras_helper import query_model_configs
 
 from utilities.minibatch_validation import validate_bytea_var_for_minibatch
@@ -225,12 +227,27 @@ class InputValidator:
             "summary table ('{1}'). The expected columns are {2}.".format(
                 module_name, model_summary_table, cols_to_check_for))
 
+    @staticmethod
+    def _validate_gpu_config(module_name, source_table, accessible_gpus_for_seg):
+
+        summary_table = add_postfix(source_table, "_summary")
+        gpu_config = plpy.execute(
+            "SELECT {0} FROM {1}".format(INTERNAL_GPU_CONFIG, summary_table)
+            )[0][INTERNAL_GPU_CONFIG]
+        if gpu_config == DEFAULT_GPU_CONFIG:
+            _assert(0 not in accessible_gpus_for_seg,
+                "{0} error: Host(s) are missing gpus.".format(module_name))
+        else:
+            for i in gpu_config:
+                _assert(accessible_gpus_for_seg[i] != 0,
+                    "{0} error: Segment {1} does not have gpu".format(module_name, i))
 
 class FitCommonValidator(object):
     def __init__(self, source_table, validation_table, output_model_table,
                  model_arch_table, model_id, dependent_varname,
                  independent_varname, num_iterations,
-                 metrics_compute_frequency, warm_start, module_name):
+                 metrics_compute_frequency, warm_start,
+                 use_gpus, accessible_gpus_for_seg, module_name):
         self.source_table = source_table
         self.validation_table = validation_table
         self.output_model_table = output_model_table
@@ -250,8 +267,12 @@ class FitCommonValidator(object):
         if self.output_model_table:
             self.output_summary_model_table = add_postfix(
                 self.output_model_table, "_summary")
+        self.accessible_gpus_for_seg = accessible_gpus_for_seg
         self.module_name = module_name
         self._validate_common_args()
+        if use_gpus:
+            InputValidator._validate_gpu_config(self.module_name,
+                self.source_table, self.accessible_gpus_for_seg)
 
     def _validate_common_args(self):
         _assert(self.num_iterations > 0,
@@ -360,7 +381,8 @@ class FitInputValidator(FitCommonValidator):
     def __init__(self, source_table, validation_table, output_model_table,
                  model_arch_table, model_id, dependent_varname,
                  independent_varname, num_iterations,
-                 metrics_compute_frequency, warm_start):
+                 metrics_compute_frequency, warm_start,
+                 use_gpus, accessible_gpus_for_seg):
 
         self.module_name = 'madlib_keras_fit'
         super(FitInputValidator, self).__init__(source_table,
@@ -373,6 +395,8 @@ class FitInputValidator(FitCommonValidator):
                                                 num_iterations,
                                                 metrics_compute_frequency,
                                                 warm_start,
+                                                use_gpus,
+                                                accessible_gpus_for_seg,
                                                 self.module_name)
         InputValidator.validate_model_arch_table(self.module_name, self.model_arch_table,
             self.model_id)
@@ -381,7 +405,8 @@ class FitMultipleInputValidator(FitCommonValidator):
     def __init__(self, source_table, validation_table, output_model_table,
                  model_selection_table, model_selection_summary_table, dependent_varname,
                  independent_varname, num_iterations, model_info_table, mst_key_col,
-                 model_arch_table_col, metrics_compute_frequency, warm_start):
+                 model_arch_table_col, metrics_compute_frequency, warm_start,
+                 use_gpus, accessible_gpus_for_seg):
 
         self.module_name = 'madlib_keras_fit_multiple'
 
@@ -408,6 +433,8 @@ class FitMultipleInputValidator(FitCommonValidator):
                                                         num_iterations,
                                                         metrics_compute_frequency,
                                                         warm_start,
+                                                        use_gpus,
+                                                        accessible_gpus_for_seg,
                                                         self.module_name)
 
         if warm_start:
@@ -465,8 +492,8 @@ class MstLoaderInputValidator():
                 res = parse_and_validate_fit_params(fit_params)
             except Exception as e:
                 plpy.error(
-                    """Fit param check failed for: {} \n
-                    {}
+                    """Fit param check failed for: {0} \n
+                    {1}
                     """.format(fit_params, str(e)))
         if not self.compile_params_list:
             plpy.error( "compile_params_list cannot be NULL")
@@ -475,8 +502,8 @@ class MstLoaderInputValidator():
                 res = parse_and_validate_compile_params(compile_params)
             except Exception as e:
                 plpy.error(
-                    """Compile param check failed for: {} \n
-                    {}
+                    """Compile param check failed for: {0} \n
+                    {1}
                     """.format(compile_params, str(e)))
 
     def _validate_input_output_tables(self):
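
_validate_gpu_config reads __internal_gpu_config__ from the preprocessor
summary table and cross-checks it against the per-segment GPU counts. A sketch
of that check with the summary-table lookup replaced by a plain argument;
'all_segments' mirrors DEFAULT_GPU_CONFIG and a bare assert stands in for the
_assert/plpy error reporting.

    # Sketch of the check in InputValidator._validate_gpu_config; the real
    # code reads gpu_config from the <source_table>_summary table.
    def check_gpu_config(gpu_config, accessible_gpus_for_seg):
        if gpu_config == 'all_segments':
            # preprocessor expected GPUs everywhere: no segment may report 0
            assert 0 not in accessible_gpus_for_seg, "Host(s) are missing gpus."
        else:
            # preprocessor pinned specific segments: each of them needs a GPU
            for seg_id in gpu_config:
                assert accessible_gpus_for_seg[seg_id] != 0, \
                    "Segment {0} does not have gpu".format(seg_id)

    check_gpu_config('all_segments', [1, 1, 1])   # passes
    check_gpu_config([0, 1], [1, 1, 0, 1])        # passes
    # check_gpu_config([0, 1], [1, 0, 0, 1])      # raises: segment 1 has no GPU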
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index c9511d9..518d20a 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -30,6 +30,7 @@ import keras.optimizers as opt
 import keras.losses as losses
 
 import madlib_keras_serializer
+import madlib_keras_gpu_info
 from utilities.utilities import _assert
 from utilities.utilities import is_platform_pg
 
@@ -57,7 +58,7 @@ def reset_cuda_env(value):
         if CUDA_VISIBLE_DEVICES_KEY in os.environ:
             del os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
-def get_device_name_and_set_cuda_env(gpus_per_host, seg):
+def get_device_name_and_set_cuda_env_predict(gpus_per_host, seg):
     if gpus_per_host > 0:
         device_name = '/gpu:0'
         if is_platform_pg():
@@ -70,15 +71,32 @@ def get_device_name_and_set_cuda_env(gpus_per_host, seg):
         set_cuda_env('-1')
     return device_name
 
-def set_keras_session(device_name, gpus_per_host, segments_per_host):
+def get_device_name_and_set_cuda_env(use_gpus, gpu_count, seg):
+
+    if use_gpus:
+        if gpu_count > 0:
+            device_name = '/gpu:0'
+            if is_platform_pg():
+                cuda_visible_dev = ','.join([str(i) for i in range(gpu_count)])
+            else:
+                cuda_visible_dev = str(seg % gpu_count)
+            set_cuda_env(cuda_visible_dev)
+        else:
+            plpy.error("No gpus found on {}".format(seg))
+    else: # cpu only
+        device_name = '/cpu:0'
+        set_cuda_env('-1')
+    return device_name
+
+def set_keras_session(device_name, gpu_count, segments_per_host):
     with K.tf.device(device_name):
-        session = get_keras_session(device_name, gpus_per_host, segments_per_host)
+        session = get_keras_session(device_name, gpu_count, segments_per_host)
         K.set_session(session)
 
-def get_keras_session(device_name, gpus_per_host, segments_per_host):
+def get_keras_session(device_name, gpu_count, segments_per_host):
     config = K.tf.ConfigProto()
-    if gpus_per_host > 0:
-        memory_fraction = get_gpu_memory_fraction(gpus_per_host, segments_per_host)
+    if gpu_count > 0:
+        memory_fraction = get_gpu_memory_fraction(gpu_count, segments_per_host)
         config.gpu_options.allow_growth = False
         config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
     session = tf.Session(config=config)
@@ -91,15 +109,15 @@ def clear_keras_session(sess = None):
     sess.close()
 
 
-def get_gpu_memory_fraction(gpus_per_host, segments_per_host):
+def get_gpu_memory_fraction(gpu_count, segments_per_host):
     """
     We cap the gpu memory usage to 90% of the total available gpu memory.
     This 90% is evenly distributed among the segments per gpu.
-    :param gpus_per_host:
+    :param gpu_count:
     :param segments_per_host:
     :return:
     """
-    return 0.9 / ceil(1.0 * segments_per_host / gpus_per_host)
+    return 0.9 / ceil(1.0 * segments_per_host / gpu_count)
 
 def get_model_shapes(model):
     model_shapes = []
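
With use_gpus and a per-segment GPU count, the wrapper picks a device per
segment and splits GPU memory across segments that share a GPU. A sketch of
the Greenplum path of get_device_name_and_set_cuda_env and of
get_gpu_memory_fraction (CUDA environment handling, error reporting, and the
Postgres branch are omitted); the sketch returns the CUDA_VISIBLE_DEVICES
value instead of setting it.

    from math import ceil

    def pick_device(use_gpus, gpu_count, seg_id):
        if use_gpus and gpu_count > 0:
            # segments on the same host round-robin over the visible GPUs
            return '/gpu:0', str(seg_id % gpu_count)
        return '/cpu:0', '-1'

    def gpu_memory_fraction(gpu_count, segments_per_host):
        # cap usage at 90% of GPU memory, split across segments sharing a GPU
        return 0.9 / ceil(1.0 * segments_per_host / gpu_count)

    print(pick_device(True, 2, 3))       # -> ('/gpu:0', '1')
    print(pick_device(False, 0, 3))      # -> ('/cpu:0', '-1')
    print(gpu_memory_fraction(1, 4))     # -> 0.225
    print(gpu_memory_fraction(2, 4))     # -> 0.45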
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in
index 55ee54e..cdda44a 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_evaluate.sql_in
@@ -40,7 +40,7 @@ SELECT madlib_keras_fit(
 
 -- Test that evaluate works as expected:
 DROP TABLE IF EXISTS evaluate_out;
-SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', 0);
+SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', FALSE);
 
 SELECT assert(loss  >= 0 AND
         metric  >= 0 AND
@@ -58,7 +58,7 @@ FROM evaluate_out;
 -- Test that evaluate errors out correctly if mst_key is given for non-multi model tables
 DROP TABLE IF EXISTS evaluate_out;
 SELECT assert(trap_error($TRAP$
-    SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', 0 ,1);
+    SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', FALSE, 1);
     $TRAP$) = 1, 'Should error out if mst_key is given for non-multi model tables');
 
 -- Test that evaluate errors out correctly if model_arch field missing from fit output
@@ -96,11 +96,11 @@ SELECT madlib_keras_fit_multiple_model(
     'cifar_10_multiple_model',
     'mst_table',
     6,
-    0
+    FALSE
 );
 
 DROP TABLE IF EXISTS evaluate_out;
-SELECT madlib_keras_evaluate('cifar_10_multiple_model', 'cifar_10_sample_batched', 'evaluate_out', 0, 2);
+SELECT madlib_keras_evaluate('cifar_10_multiple_model', 'cifar_10_sample_batched', 'evaluate_out', FALSE, 2);
 SELECT assert(relative_error(e.metric,i.training_metrics_final) < 0.00001 AND
         relative_error(e.loss,i.training_loss_final)  < 0.00001 AND
         e.metrics_type = '{accuracy}', 'Evaluate output validation failed.')
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index 591beec..8b8dfcb 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -122,7 +122,7 @@ SELECT assert(trap_error($TRAP$madlib_keras_fit(
     $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
     $$ batch_size=2, epochs=1, verbose=0 $$::text,
     3,
-    2,
+    True,
     'cifar_10_sample_val');$TRAP$) = 1,
        'Fit with gpus_per_host=2 must error out.');
 
@@ -267,7 +267,7 @@ SELECT madlib_keras_fit(
     $$ optimizer=Adam(epsilon=None), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
     $$ batch_size=2, epochs=1, verbose=0 $$::text,
     1,
-    0,
+    FALSE,
     NULL,
     NULL,
     NULL, 'model name', 'model desc');
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
index 60927b0..66663f6 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -38,7 +38,7 @@ SELECT madlib_keras_fit(
 	$$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
   $$batch_size=16, epochs=1$$,
 	3,
-	0
+	FALSE
 );
 
 SELECT assert(
@@ -76,7 +76,7 @@ SELECT madlib_keras_evaluate(
     'iris_model',
     'iris_data_val',
     'evaluate_out',
-    0);
+    FALSE);
 
 SELECT assert(loss >= 0 AND
         metric >= 0 AND
@@ -93,7 +93,7 @@ SELECT madlib_keras_fit(
 	$$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
   $$batch_size=16, epochs=1$$,
 	3,
-	0
+	FALSE
 );
 
 SELECT assert(
@@ -131,7 +131,7 @@ SELECT madlib_keras_evaluate(
     'iris_model',
     'iris_data_one_hot_encoded_val',
     'evaluate_out',
-    0);
+    FALSE);
 
 SELECT assert(loss >= 0 AND
         metric >= 0 AND
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index adc771e..c36d48f 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -187,7 +187,7 @@ SELECT madlib_keras_fit_multiple_model(
 	'iris_multiple_model',
 	'mst_table_4row',
 	3,
-	0
+	FALSE
 );
 
 SELECT assert(
@@ -218,7 +218,7 @@ SELECT madlib_keras_fit_multiple_model(
 	'iris_multiple_model',
 	'mst_table',
 	6,
-	0,
+	FALSE,
 	'iris_data_one_hot_encoded_packed'
 );
 
@@ -295,7 +295,7 @@ SELECT madlib_keras_fit_multiple_model(
 	'iris_multiple_model',
 	'mst_table_1row',
 	3,
-	0,
+	FALSE,
 	NULL,
 	1,
 	FALSE,
@@ -339,7 +339,7 @@ SELECT madlib_keras_fit_multiple_model(
 	'iris_multiple_model',
 	'mst_table_4row',
 	3,
-	0
+	FALSE
 );
 
 SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
index 355637f..eb4dd10 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
@@ -51,7 +51,7 @@ SELECT madlib_keras_fit_multiple_model(
 	'iris_multiple_model',
 	'mst_table',
 	3,
-	0
+	FALSE
 );
 
 SELECT assert(
@@ -106,7 +106,7 @@ SELECT madlib_keras_fit_multiple_model(
 	'iris_multiple_model',
 	'mst_table',
 	3,
-	0
+	FALSE
 );
 
 SELECT assert(
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
index ec5a1f2..5d52574 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
@@ -395,7 +395,7 @@ SELECT madlib_keras_fit_multiple_model(
     'iris_multiple_model',
     'mst_table',
     6,
-    0
+    FALSE
 );
 
 DROP TABLE IF EXISTS iris_predict;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
index 0ab09b7..6a7ec1f 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
@@ -146,7 +146,7 @@ SELECT madlib_keras_fit_multiple_model(
   'iris_multiple_model',
   'mst_table',
   3,
-  0, NULL, 1
+  FALSE, NULL, 1
 );
 
 DROP TABLE IF EXISTS iris_model_first_run;
@@ -161,7 +161,7 @@ SELECT madlib_keras_fit_multiple_model(
   'iris_multiple_model',
   'mst_table',
   3,
-  0,
+  FALSE,
   NULL, 1,
   TRUE -- warm_start
 );
@@ -205,7 +205,7 @@ SELECT madlib_keras_fit_multiple_model(
   'iris_multiple_model',
   'mst_table',
   3,
-  0, NULL, 1
+  FALSE, NULL, 1
 );
 
 DROP TABLE IF EXISTS iris_model_first_run;
@@ -221,7 +221,7 @@ SELECT madlib_keras_fit_multiple_model(
   'iris_multiple_model',
   'mst_table',
   3,
-  0, NULL, 1,
+  FALSE, NULL, 1,
   TRUE);
 
 
@@ -242,7 +242,7 @@ SELECT assert(trap_error($TRAP$madlib_keras_fit_multiple_model(
   'iris_multiple_model',
   'mst_table',
   3,
-  0,
+  FALSE,
   NULL, 1,
   TRUE -- warm_start
 );$TRAP$) = 1, 'Warm start with extra mst keys should fail.');
@@ -308,7 +308,7 @@ SELECT madlib_keras_fit_multiple_model(
   'iris_multiple_model',
   'mst_table',
   3,
-  0
+  FALSE
 );
 
 SELECT assert(
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 901ebbd..6097286 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -34,8 +34,6 @@ import unittest
 from mock import *
 import plpy_mock as plpy
 
-
-
 # helper for multiplying array by int
 def mult(k,arr):
     return [ k*a for a in arr ]
@@ -65,8 +63,8 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.model_weights = [3,4,5,6]
         self.serialized_weights = np.array(self.model_weights, dtype=np.float32
                                            ).tostring()
-
-        self.all_seg_ids = [0,1,2]
+        self.dist_key_mapping = [0,1,2]
+        self.accessible_gpus_for_seg = [0]
 
         self.independent_var_real = [[[[0.5]]]] * 10
         self.dependent_var_int = [[0,1]] * 10
@@ -102,8 +100,9 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.all_seg_ids, self.total_images_per_seg, 0, 4,
-            previous_state.tostring(), True, **k)
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_state.tostring(), True, **k)
+
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
@@ -127,8 +126,9 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.all_seg_ids, self.total_images_per_seg, 0, 4,
-            previous_state.tostring(), False, **k)
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_state.tostring(), False, **k)
+
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
@@ -154,11 +154,12 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         k = {'SD': {}}
 
         new_state = self.subject.fit_transition(
-            None, self.dependent_var, self.independent_var ,
-            self.dependent_var_shape, self.independent_var_shape, self.model.to_json(),
-            self.compile_params, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, previous_weights.tostring(), True,
-            True, **k)
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, True, **k)
+
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
@@ -194,8 +195,9 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, 'dummy_previous_state', True, **k)
+            self.model.to_json(), None, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, 'dummy_previous_state', True, **k)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -223,8 +225,9 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, 'dummy_previous_state', False, **k)
+            self.model.to_json(), None, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, 'dummy_previous_state', False, **k)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -259,8 +262,9 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, 'dummy_previous_state', True, True, **k)
+            self.model.to_json(), None, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, 'dummy_previous_state', True, True, **k)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -293,11 +297,13 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.subject.compile_and_set_weights(self.model, self.compile_params,
                                              '/cpu:0', self.serialized_weights)
         k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+
         new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, 'dummy_previous_state', True, **k)
+            self.model.to_json(), None, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, 'dummy_previous_state', True, **k)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -323,11 +329,13 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.subject.compile_and_set_weights(self.model, self.compile_params,
                                              '/cpu:0', self.serialized_weights)
         k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+
         new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, 'dummy_previous_state', False, **k)
+            self.model.to_json(), None, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, 'dummy_previous_state', False, **k)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -359,11 +367,13 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.subject.compile_and_set_weights(self.model, self.compile_params,
                                              '/cpu:0', self.serialized_weights)
         k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+
         new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape, self.model.to_json(),
-            None, self.fit_params, 0, self.all_seg_ids, self.total_images_per_seg,
-            0, 4, 'dummy_previous_state', True, True, **k)
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), None, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, 'dummy_previous_state', True, True, **k)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -399,20 +409,20 @@ class MadlibKerasFitTestCase(unittest.TestCase):
                          self.subject.fit_transition('dummy_state', [0], None,
                                                      'noshape', 'noshape',
                                                      'dummy_model_json', "foo", "bar",
-                                                     1, [0,1,2], [3,3,3], 0, 4,
-                                                     'dummy_prev_state', **k))
+                                                     1, [0,1,2], 0, 4, [3,3,3], False,
+                                                     [0], 'dummy_prev_state', **k))
         self.assertEqual('dummy_state',
                          self.subject.fit_transition('dummy_state', None, [[0.5]],
                                                      'noshape', 'noshape',
                                                      'dummy_model_json', "foo", "bar",
-                                                     1, [0,1,2], [3,3,3], 0, 4,
-                                                     'dummy_prev_state', **k))
+                                                     1, [0,1,2], 0, 4, [3,3,3], False,
+                                                     [0], 'dummy_prev_state', **k))
         self.assertEqual('dummy_state',
                          self.subject.fit_transition('dummy_state', None, None,
                                                      'noshape', 'noshape',
                                                      'dummy_model_json', "foo", "bar",
-                                                     1, [0,1,2], [3,3,3], 0, 4,
-                                                     'dummy_prev_state', **k))
+                                                     1, [0,1,2], 0, 4, [3,3,3], False,
+                                                     [0], 'dummy_prev_state', **k))
 
     def test_fit_merge(self):
         image_count = self.total_images_per_seg[0]
@@ -707,6 +717,8 @@ class MadlibKerasWrapperTestCase(unittest.TestCase):
         import madlib_keras_wrapper
         self.subject = madlib_keras_wrapper
 
+        self.use_gpus = False
+
     def tearDown(self):
         self.module_patcher.stop()
 
@@ -739,12 +751,12 @@ class MadlibKerasWrapperTestCase(unittest.TestCase):
         gpus_per_host = 3
 
         self.assertEqual('/gpu:0', self.subject.get_device_name_and_set_cuda_env(
-            gpus_per_host, seg_id ))
+            True, gpus_per_host, seg_id ))
         self.assertEqual('0,1,2', os.environ['CUDA_VISIBLE_DEVICES'])
 
         gpus_per_host = 0
         self.assertEqual('/cpu:0', self.subject.get_device_name_and_set_cuda_env(
-            gpus_per_host, seg_id ))
+            False, gpus_per_host, seg_id ))
         self.assertEqual('-1', os.environ['CUDA_VISIBLE_DEVICES'])
 
     def test_get_device_name_and_set_cuda_env_gpdb(self):
@@ -753,12 +765,12 @@ class MadlibKerasWrapperTestCase(unittest.TestCase):
         seg_id=3
         gpus_per_host=2
         self.assertEqual('/gpu:0', self.subject.get_device_name_and_set_cuda_env(
-            gpus_per_host, seg_id))
+            True, gpus_per_host, seg_id))
         self.assertEqual('1', os.environ['CUDA_VISIBLE_DEVICES'])
 
         gpus_per_host=0
         self.assertEqual('/cpu:0', self.subject.get_device_name_and_set_cuda_env(
-            gpus_per_host, seg_id))
+            False, gpus_per_host, seg_id))
         self.assertEqual('-1', os.environ['CUDA_VISIBLE_DEVICES'])
 
 
@@ -1038,28 +1050,32 @@ class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
         self.subject.FitCommonValidator._validate_common_args = Mock()
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', 5, None, False, 'module_name')
+            'dep_varname', 'independent_varname', 5, None, False, False, [0],
+            'module_name')
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_True_num(self):
         self.subject.FitCommonValidator._validate_common_args = Mock()
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', 5, 3, False, 'module_name')
+            'dep_varname', 'independent_varname', 5, 3, False, False, [0],
+            'module_name')
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_zero(self):
         self.subject.FitCommonValidator._validate_common_args = Mock()
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', 5, 0, False, 'module_name')
+            'dep_varname', 'independent_varname', 5, 0, False, False, [0],
+            'module_name')
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_greater(self):
         self.subject.FitCommonValidator._validate_common_args = Mock()
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
-            'dep_varname', 'independent_varname', 5, 6, False, 'module_name')
+            'dep_varname', 'independent_varname', 5, 6, False, False, [0],
+            'module_name')
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
 
@@ -1159,6 +1175,29 @@ class InputValidatorTestCase(unittest.TestCase):
                 self.module_name, self.model_arch_table, None)
         self.assertIn('id', str(error.exception).lower())
 
+    def test_validate_gpu_config_with_gpu_all_segments(self):
+
+        self.plpy_mock_execute.return_value = [{'__internal_gpu_config__': 'all_segments'}]
+        obj = self.subject._validate_gpu_config(self.module_name, 'foo', [1])
+
+    def test_validate_gpu_config_no_gpu_all_segments(self):
+
+        self.plpy_mock_execute.return_value = [{'__internal_gpu_config__': 'all_segments'}]
+        with self.assertRaises(plpy.PLPYException) as error:
+            obj = self.subject._validate_gpu_config(self.module_name, 'foo', [0])
+        self.assertIn('missing gpus', str(error.exception).lower())
+
+    def test_validate_gpu_config_with_gpu_valid_seg_list(self):
+
+        self.plpy_mock_execute.return_value = [{'__internal_gpu_config__': [0,1]}]
+        obj = self.subject._validate_gpu_config(self.module_name, 'foo', [1,1,0,1])
+
+    def test_validate_gpu_config_with_gpu_invalid_seg_list(self):
+
+        self.plpy_mock_execute.return_value = [{'__internal_gpu_config__': [0,1]}]
+        with self.assertRaises(plpy.PLPYException) as error:
+            obj = self.subject._validate_gpu_config(self.module_name, 'foo', [1,0,0,1])
+        self.assertIn('does not have gpu', str(error.exception).lower())
 
 class MadlibSerializerTestCase(unittest.TestCase):
     def setUp(self):
@@ -1289,6 +1328,45 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
                          self.subject.strip_trailing_nulls_from_class_values(
                 [None, None]))
 
+    def test_get_gpus_per_one_seg_gpu_gpdb(self):
+
+        self.subject.is_platform_pg = Mock(return_value = False)
+
+        self.plpy_mock_execute.side_effect = \
+            [ [],
+            [ {'hostname': 'mdw0', 'count' : 1}],
+            [ {'hostname': 'mdw0', 'segment_id' : 0},
+              {'hostname': 'mdw1', 'segment_id' : 1},
+              {'hostname': 'mdw2', 'segment_id' : 2}
+            ]]
+
+        self.assertEqual([1,0,0], self.subject.get_accessible_gpus_for_seg(
+            'schema_madlib', 2, 'foo'))
+
+    def test_get_gpus_per_mult_seg_gpu_gpdb(self):
+
+        self.subject.is_platform_pg = Mock(return_value = False)
+
+        self.plpy_mock_execute.side_effect = \
+            [[],
+            [ {'hostname': 'mdw0', 'count' : 1}],
+            [ {'hostname': 'mdw0', 'segment_id' : 0},
+              {'hostname': 'mdw0', 'segment_id' : 1},
+              {'hostname': 'mdw1', 'segment_id' : 2},
+              {'hostname': 'mdw1', 'segment_id' : 3}
+            ]]
+
+        self.assertEqual([1,1,0,0], self.subject.get_accessible_gpus_for_seg(
+            'schema_madlib', 2, 'foo'))
+
+    def test_get_gpus_per_no_gpu_gpdb(self):
+
+        self.subject.is_platform_pg = Mock(return_value = False)
+
+        self.plpy_mock_execute.side_effect = [[],[]]
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject.get_accessible_gpus_for_seg('schema_madlib', 2, 'foo')
+        self.assertIn('no gpus configured on hosts', str(error.exception).lower())
 
 class MadlibKerasEvaluationTestCase(unittest.TestCase):
     def setUp(self):
@@ -1316,7 +1394,8 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
                                            ).tostring()
         self.loss = 0.5947071313858032
         self.accuracy = 1.0
-        self.all_seg_ids = [0,1,2]
+        self.dist_key_mapping = [0,1,2]
+        self.accessible_gpus_for_seg = [0]
 
         #self.model.evaluate = Mock(return_value = [self.loss, self.accuracy])
 
@@ -1351,8 +1430,9 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
             state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(),
-            self.serialized_weights, self.compile_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 3, True, **k)
+            self.serialized_weights, self.compile_params, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg, True, **k)
 
         agg_loss, agg_accuracy, image_count = new_state
 
@@ -1372,11 +1452,14 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         self.subject.K.clear_session.reset_mock()
         k = {'SD' : {}}
         state = [0,0,0]
+
         new_state = self.subject.internal_keras_eval_transition(
             state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), self.serialized_weights, self.compile_params,
-            0, self.all_seg_ids, self.total_images_per_seg, 0, 3, False, **k)
+            self.model.to_json(),
+            self.serialized_weights, self.compile_params, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg, False, **k)
         agg_loss, agg_accuracy, image_count = new_state
 
         self.assertEqual(ending_image_count, image_count)
@@ -1414,8 +1497,9 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
             state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(),
-            'dummy_model_weights', None, 0,self.all_seg_ids,
-            self.total_images_per_seg, 0, 3, True, **k)
+            'dummy_model_weights', None, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg, True, **k)
 
         agg_loss, agg_accuracy, image_count = new_state
 
@@ -1445,9 +1529,9 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
             state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(),
-            'dummy_model_weights', None, 0,self.all_seg_ids,
-            self.total_images_per_seg, 0, 3, False, **k)
-
+            'dummy_model_weights', None, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg, False, **k)
         agg_loss, agg_accuracy, image_count = new_state
 
         self.assertEqual(ending_image_count, image_count)
@@ -1485,9 +1569,9 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
             state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(),
-            'dummy_model_weights', None, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 3, True, **k)
-
+            'dummy_model_weights', None, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg, True, **k)
         agg_loss, agg_accuracy, image_count = new_state
 
         self.assertEqual(ending_image_count, image_count)
@@ -1520,8 +1604,9 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
             state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(),
-            'dummy_model_weights', None, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 3, False, **k)
+            'dummy_model_weights', None, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg, False, **k)
 
         agg_loss, agg_accuracy, image_count = new_state