Posted to commits@madlib.apache.org by kh...@apache.org on 2021/01/29 00:04:17 UTC

[madlib] branch master updated (3b66baa -> f978b3b)

This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git.


    from 3b66baa  DL: Add multiple variable support
     new f833ac0  utilities: Add new function for getting data distribution per segment
     new f978b3b  DL: Fix gpu mem fraction calc when data isn't distributed to all segs

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../modules/deep_learning/madlib_keras.py_in       |  44 +++---
 .../modules/deep_learning/madlib_keras.sql_in      |  26 ++--
 .../madlib_keras_fit_multiple_model.py_in          |  21 ++-
 .../madlib_keras_fit_multiple_model.sql_in         |   2 +-
 .../deep_learning/madlib_keras_helper.py_in        |  41 +++++-
 .../deep_learning/madlib_keras_predict.py_in       |  12 +-
 .../deep_learning/test/madlib_keras_fit.sql_in     |  11 ++
 .../test/madlib_keras_fit_multiple.sql_in          |  15 +-
 .../test/madlib_keras_iris.setup.sql_in            |  18 +++
 .../test/unit_tests/test_madlib_keras.py_in        | 158 ++++++++++++++++-----
 10 files changed, 249 insertions(+), 99 deletions(-)


[madlib] 01/02: utilities: Add new function for getting data distribution per segment

Posted by kh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit f833ac08c59e576d4cdd80b3ad4c4115bfa6f93d
Author: Ekta Khanna <ek...@pivotal.io>
AuthorDate: Tue Jan 19 17:12:15 2021 -0800

    utilities: Add new function for getting data distribution per segment
    
    JIRA: MADLIB-1463
    
    This commit adds a new function that returns a list with one entry per
    segment: the count of data-bearing segments on that segment's host, or
    zero if the input table's data is not distributed to that segment.

    This function will be useful for the deep learning module, where the
    user can choose, via the image preprocessor, to distribute the data to
    only part of the cluster.
    
    Co-authored-by: Nikhil Kak <nk...@vmware.com>
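
    For context, a minimal sketch (not part of this patch) of how such a
    per-segment list is meant to be consumed downstream: each segment looks
    up its own index and shares host resources with the data-bearing
    segments counted there, instead of dividing by a cluster-wide constant.
    The helper name and arguments below are illustrative only.

    # Hypothetical consumer of the distribution list; current_seg_id and
    # gpus_on_host are assumed inputs, not names introduced by this commit.
    def resource_share_for_segment(distribution, current_seg_id, gpus_on_host):
        data_segs_on_my_host = distribution[current_seg_id]  # 0 => no data here
        if data_segs_on_my_host == 0:
            return 0.0        # this segment never trains, reserve nothing
        return float(gpus_on_host) / data_segs_on_my_host

    # Example from the docstring: 2 hosts x 3 segs, data on seg0, seg1, seg3.
    # seg0 shares host1's single GPU with seg1 -> 0.5
    print(resource_share_for_segment([2, 2, 0, 1, 0, 0], 0, 1))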
---
 .../deep_learning/madlib_keras_helper.py_in        | 36 ++++++++++++++
 .../test/unit_tests/test_madlib_keras.py_in        | 58 ++++++++++++++++++++++
 2 files changed, 94 insertions(+)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index 2dd17aa..735f1b2 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -410,3 +410,39 @@ def generate_row_string(configs_dict):
     if result_row_string[0] == ',':
         return result_row_string[1:]
     return result_row_string
+
+def get_data_distribution_per_segment(table_name):
+    """
+    Returns a list with the count of segments on each host that the input
+    table's data is distributed on.
+    :param table_name: input table name
+    :return: len(return list) = total num of segments in cluster
+    Each index of the array/list represents a segment of the cluster. If the data
+    is not distributed on that segment, then that index's value will be set to zero.
+    Otherwise the value will be set to the count of segments that have the data on
+    that segment's host.
+    E.g. if there are 2 hosts and 3 segs per host:
+    host1 - seg0, seg1, seg2
+    host2 - seg3, seg4, seg5
+    If the data is distributed on seg0, seg1 and seg3 then the return value will be
+    [2,2,0,1,0,0]
+    """
+    if is_platform_pg():
+        return [1]
+    else:
+        res = plpy.execute("""
+                    WITH cte AS (SELECT DISTINCT(gp_segment_id)
+                                 FROM {table_name})
+                    SELECT content, count as cnt
+                        FROM gp_segment_configuration
+                        JOIN (SELECT hostname, count(*)
+                              FROM gp_segment_configuration
+                              WHERE content in (SELECT * FROM cte)
+                              GROUP BY hostname) a
+                        USING (hostname)
+                        WHERE content in (SELECT * FROM cte)
+                    ORDER BY 1""".format(table_name=table_name))
+        data_distribution_per_segment = [0] * get_seg_number()
+        for r in res:
+            data_distribution_per_segment[r['content']] = int(r['cnt'])
+        return data_distribution_per_segment
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 64395eb..31a61a8 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -1629,6 +1629,64 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
             "This is not valid PostgresSQL: SELECT {}[1]".format(metrics)
         )
 
+    def test_get_data_distribution_per_segment_all_segments(self):
+        # config: 3 hosts 3 seg per host
+        # data: all segments on all hosts
+        self.subject.is_platform_pg = Mock(return_value = False)
+        self.subject.get_seg_number = Mock(return_value=9)
+        self.plpy_mock_execute.side_effect = \
+            [[ {'content': 0, 'cnt' : 3},
+               {'content': 1, 'cnt' : 3},
+               {'content': 2, 'cnt' : 3},
+               {'content': 3, 'cnt' : 3},
+               {'content': 4, 'cnt' : 3},
+               {'content': 5, 'cnt' : 3},
+               {'content': 6, 'cnt' : 3},
+               {'content': 7, 'cnt' : 3},
+               {'content': 8, 'cnt' : 3}
+               ]]
+        res = self.subject.get_data_distribution_per_segment('source_table')
+        self.assertEqual([3,3,3,3,3,3,3,3,3],res)
+
+    def test_get_data_distribution_per_segment_on_some_hosts(self):
+        # config: 3 hosts 3 seg per host
+        # data: all segments on 2 hosts
+        self.subject.is_platform_pg = Mock(return_value = False)
+        self.subject.get_seg_number = Mock(return_value=9)
+        self.plpy_mock_execute.side_effect = \
+            [[ {'content': 0, 'cnt' : 3},
+               {'content': 1, 'cnt' : 3},
+               {'content': 2, 'cnt' : 3},
+               {'content': 3, 'cnt' : 3},
+               {'content': 4, 'cnt' : 3},
+               {'content': 5, 'cnt' : 3}
+               ]]
+        res = self.subject.get_data_distribution_per_segment('source_table')
+        self.assertEqual([3,3,3,3,3,3,0,0,0],res)
+
+    def test_get_data_distribution_per_segment_some_segments(self):
+        # config: 3 hosts 3 seg per host
+        # data: all seg host 1, 2 seg on host 2 and 1 seg on host 3
+        self.subject.is_platform_pg = Mock(return_value = False)
+        self.subject.get_seg_number = Mock(return_value=9)
+        self.plpy_mock_execute.side_effect = \
+            [[ {'content': 0, 'cnt' : 3},
+               {'content': 1, 'cnt' : 3},
+               {'content': 2, 'cnt' : 3},
+               {'content': 3, 'cnt' : 2},
+               {'content': 4, 'cnt' : 2},
+               {'content': 8, 'cnt' : 1}
+               ]]
+        res = self.subject.get_data_distribution_per_segment('source_table')
+        self.assertEqual([3,3,3,2,2,0,0,0,1],res)
+
+    def test_get_data_distribution_per_segment_postgres(self):
+        # config: PostgreSQL, single node with no Greenplum segments
+        # data: everything lives on the single node, so the result is [1]
+        self.subject.is_platform_pg = Mock(return_value = True)
+        res = self.subject.get_data_distribution_per_segment('source_table')
+        self.assertEqual([1],res)
+
 class MadlibKerasEvaluationMergeFinalTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
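
    To connect the Greenplum query and the Python fill loop in the new helper,
    here is a small standalone sketch using the 2-host / 3-segs-per-host
    example from the docstring (data on seg0, seg1 and seg3). The hard-coded
    rows stand in for what the gp_segment_configuration join would return;
    only data-bearing segments appear, and the loop leaves the rest at zero:

    # Rows as the query would return them for the docstring example.
    res = [{'content': 0, 'cnt': 2},
           {'content': 1, 'cnt': 2},
           {'content': 3, 'cnt': 1}]

    total_segments = 6                   # 2 hosts x 3 segments per host
    distribution = [0] * total_segments  # segments without data stay at 0
    for r in res:
        distribution[r['content']] = int(r['cnt'])

    print(distribution)                  # [2, 2, 0, 1, 0, 0]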


[madlib] 02/02: DL: Fix gpu mem fraction calc when data isn't distributed to all segs

Posted by kh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit f978b3b1087554da6e3986edef401879a4e77484
Author: Ekta Khanna <ek...@pivotal.io>
AuthorDate: Thu Jan 21 11:57:53 2021 -0800

    DL: Fix gpu mem fraction calc when data isn't distributed to all segs
    
    JIRA: MADLIB-1463
    
    Previously, the calculation of `gpu_mem_fraction` assumed that
    num_segments = all_segments, which is not always the case. The user can
    pass a distribution rules table to the input preprocessor, so the data
    may be distributed to fewer segments than the total number of segments
    in the cluster.
    
    This commit replaces the get_segments_per_host function call with
    get_data_distribution_per_segment, which returns the actual per-segment
    distribution of the data instead of a single cluster-wide segment count.
    Using this, we can calculate the correct memory fraction.
    
    Co-authored-by: Nikhil Kak <nk...@vmware.com>
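
    A hedged sketch of the memory-fraction idea being fixed here (the exact
    formula lives in madlib_keras_wrapper.py and is not shown in this diff;
    this version is only assumed to match it because it reproduces the value
    asserted in the unit test below, e.g. 3 GPUs and 4 data-bearing segments
    per host -> 0.45):

    from math import ceil

    # Cap usage at 90% of GPU memory and split it among the segments that
    # actually hold data on a host (previously: all segments on the host).
    def gpu_memory_fraction(gpus_per_host, data_segs_per_host):
        return 0.9 / ceil(1.0 * data_segs_per_host / gpus_per_host)

    print(gpu_memory_fraction(3, 4))  # 0.45, as asserted in the test
    print(gpu_memory_fraction(2, 4))  # 0.45 as well: 2 segments per GPU

    With the per-segment distribution list, each segment passes the count at
    its own index (segments_per_host[current_seg_id]) rather than a single
    cluster-wide number, so tables packed onto a subset of segments no longer
    end up with an unnecessarily small fraction.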
---
 .../modules/deep_learning/madlib_keras.py_in       |  44 +++++----
 .../modules/deep_learning/madlib_keras.sql_in      |  26 +++---
 .../madlib_keras_fit_multiple_model.py_in          |  21 +++--
 .../madlib_keras_fit_multiple_model.sql_in         |   2 +-
 .../deep_learning/madlib_keras_helper.py_in        |   5 +-
 .../deep_learning/madlib_keras_predict.py_in       |  12 +--
 .../deep_learning/test/madlib_keras_fit.sql_in     |  11 +++
 .../test/madlib_keras_fit_multiple.sql_in          |  15 ++--
 .../test/madlib_keras_iris.setup.sql_in            |  18 ++++
 .../test/unit_tests/test_madlib_keras.py_in        | 100 +++++++++++++--------
 10 files changed, 155 insertions(+), 99 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 45c3840..49892b6 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -35,7 +35,6 @@ from internal.db_utils import quote_literal
 from utilities.utilities import _assert
 from utilities.utilities import add_postfix
 from utilities.utilities import is_platform_pg
-from utilities.utilities import get_segments_per_host
 from utilities.utilities import get_seg_number
 from utilities.utilities import madlib_version
 from utilities.utilities import unique_string
@@ -104,10 +103,12 @@ def fit(schema_madlib, source_table, model, model_arch_table,
     fit_params = "" if not fit_params else fit_params
     _assert(compile_params, "Compile parameters cannot be empty or NULL.")
 
-    segments_per_host = get_segments_per_host()
+    segments_per_host = get_data_distribution_per_segment(source_table)
     use_gpus = use_gpus if use_gpus else False
     if use_gpus:
-        accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib, segments_per_host, module_name)
+        accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib,
+                                                              segments_per_host,
+                                                              module_name)
     else:
         accessible_gpus_for_seg = get_seg_number()*[0]
 
@@ -258,7 +259,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
             {dist_key_col},
             ARRAY{dist_key_mapping},
             {gp_segment_id_col},
-            {segments_per_host},
+            ARRAY{segments_per_host},
             ARRAY{images_per_seg_train},
             ARRAY{accessible_gpus_for_seg},
             $1,
@@ -322,6 +323,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
                                                    model_arch,
                                                    serialized_weights, use_gpus,
                                                    accessible_gpus_for_seg,
+                                                   segments_per_host,
                                                    dist_key_mapping,
                                                    images_per_seg_train,
                                                    training_metrics,
@@ -341,6 +343,7 @@ def fit(schema_madlib, source_table, model, model_arch_table,
                                                            serialized_weights,
                                                            use_gpus,
                                                            accessible_gpus_for_seg,
+                                                           segments_per_host,
                                                            dist_key_mapping_val,
                                                            images_per_seg_val,
                                                            validation_metrics,
@@ -523,11 +526,11 @@ def get_source_summary_table_dict(source_summary_table):
 
     return source_summary
 
-def compute_loss_and_metrics(schema_madlib, table, columns_dict, compile_params, model_arch,
-                             serialized_weights, use_gpus, accessible_gpus_for_seg,
-                             dist_key_mapping, images_per_seg_val,
-                             metrics_list, loss_list,
-                             should_clear_session, custom_fn_map,
+def compute_loss_and_metrics(schema_madlib, table, columns_dict, compile_params,
+                             model_arch, serialized_weights, use_gpus,
+                             accessible_gpus_for_seg, segments_per_host,
+                             dist_key_mapping, images_per_seg_val, metrics_list,
+                             loss_list, should_clear_session, custom_fn_map,
                              model_table=None, mst_key=None, is_train=True):
     """
     Compute the loss and metric using a given model (serialized_weights) on the
@@ -542,6 +545,7 @@ def compute_loss_and_metrics(schema_madlib, table, columns_dict, compile_params,
                                                    serialized_weights,
                                                    use_gpus,
                                                    accessible_gpus_for_seg,
+                                                   segments_per_host,
                                                    dist_key_mapping,
                                                    images_per_seg_val,
                                                    should_clear_session,
@@ -643,7 +647,7 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
 
     segment_model, sess = get_init_model_and_sess(GD, device_name,
         accessible_gpus_for_seg[current_seg_id],
-        segments_per_host,
+        segments_per_host[current_seg_id],
         model_architecture, compile_params,
         custom_function_map)
 
@@ -761,7 +765,7 @@ def fit_multiple_transition_caching(dependent_var, independent_var, dependent_va
         device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
         segment_model, sess = get_init_model_and_sess(GD, device_name,
                                                       accessible_gpus_for_seg[current_seg_id],
-                                                      segments_per_host,
+                                                      segments_per_host[current_seg_id],
                                                       model_architecture, compile_params,
                                                       custom_function_map)
 
@@ -876,9 +880,11 @@ def evaluate(schema_madlib, model_table, test_table, output_table,
 
     validate_evaluate(module_name, model_table, model_summary_table, test_table, test_summary_table, output_table, is_mult_model)
 
-    segments_per_host = get_segments_per_host()
+    segments_per_host = get_data_distribution_per_segment(test_table)
     if use_gpus:
-        accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib, segments_per_host, module_name)
+        accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib,
+                                                              segments_per_host,
+                                                              module_name)
     else:
         accessible_gpus_for_seg = get_seg_number()*[0]
 
@@ -921,8 +927,8 @@ def evaluate(schema_madlib, model_table, test_table, output_table,
     loss_metric = \
         get_loss_metric_from_keras_eval(
             schema_madlib, test_table, columns_dict, compile_params, model_arch,
-            model_weights, use_gpus, accessible_gpus_for_seg, dist_key_mapping,
-            images_per_seg, custom_function_map=custom_function_map)
+            model_weights, use_gpus, accessible_gpus_for_seg, segments_per_host,
+            dist_key_mapping, images_per_seg, custom_function_map=custom_function_map)
 
     loss = loss_metric[0]
     metric = loss_metric[1]
@@ -967,7 +973,8 @@ def validate_evaluate(module_name, model_table, model_summary_table, test_table,
 
 def get_loss_metric_from_keras_eval(schema_madlib, table, columns_dict, compile_params,
                                     model_arch, serialized_weights, use_gpus,
-                                    accessible_gpus_for_seg, dist_key_mapping, images_per_seg,
+                                    accessible_gpus_for_seg, segments_per_host,
+                                    dist_key_mapping, images_per_seg,
                                     should_clear_session=True, custom_function_map=None,
                                     model_table=None, mst_key=None, is_train=True):
     """
@@ -977,7 +984,6 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, columns_dict, compile_
 
     dist_key_col = '0' if is_platform_pg() else '__table__.{0}'.format(DISTRIBUTION_KEY_COLNAME)
     gp_segment_id_col = '0' if is_platform_pg() else '__table__.{0}'.format(GP_SEGMENT_ID_COLNAME)
-    segments_per_host = get_segments_per_host()
 
     """
     This function will call the internal keras evaluate function to get the loss
@@ -1009,7 +1015,7 @@ def get_loss_metric_from_keras_eval(schema_madlib, table, columns_dict, compile_
                                             {dist_key_col},
                                             ARRAY{dist_key_mapping},
                                             {gp_segment_id_col},
-                                            {segments_per_host},
+                                            ARRAY{segments_per_host},
                                             ARRAY{images_per_seg},
                                             ARRAY{accessible_gpus_for_seg},
                                             {should_clear_session},
@@ -1090,7 +1096,7 @@ def internal_keras_eval_transition(state, dependent_var, independent_var,
 
     segment_model, sess = get_init_model_and_sess(GD, device_name,
                                                   accessible_gpus_for_seg[current_seg_id],
-                                                  segments_per_host,
+                                                  segments_per_host[current_seg_id],
                                                   model_architecture,
                                                   compile_params, custom_function_map)
     if not agg_image_count:
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 9896fae..da61e8e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -1676,7 +1676,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     dist_key                    INTEGER,
     dist_key_mapping            INTEGER[],
     current_seg_id              INTEGER,
-    segments_per_host           INTEGER,
+    segments_per_host           INTEGER[],
     images_per_seg              INTEGER[],
     accessible_gpus_for_seg     INTEGER[],
     prev_serialized_weights     BYTEA,
@@ -1718,7 +1718,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_wide(
     dist_key                    INTEGER,
     dist_key_mapping            INTEGER[],
     current_seg_id              INTEGER,
-    segments_per_host           INTEGER,
+    segments_per_host           INTEGER[],
     images_per_seg              INTEGER[],
     accessible_gpus_for_seg     INTEGER[],
     prev_serialized_weights     BYTEA,
@@ -1797,7 +1797,7 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step(
     /* dist_key */               INTEGER,
     /* dist_key_mapping */       INTEGER[],
     /* current_seg_id */         INTEGER,
-    /* segments_per_host */      INTEGER,
+    /* segments_per_host */      INTEGER[],
     /* images_per_seg */         INTEGER[],
     /* segments_per_host  */     INTEGER[],
     /* prev_serialized_weights */BYTEA,
@@ -1821,18 +1821,17 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
     /* dist_key */               INTEGER,
     /* dist_key_mapping */       INTEGER[],
     /* current_seg_id */         INTEGER,
-    /* segments_per_host */      INTEGER,
+    /* segments_per_host */      INTEGER[],
     /* images_per_seg */         INTEGER[],
     /* segments_per_host  */     INTEGER[],
     /* prev_serialized_weights */BYTEA,
     /* custom_loss_cfunction */  BYTEA
-)(
+    )(
     STYPE=BYTEA,
     SFUNC=MADLIB_SCHEMA.fit_transition_wide,
     m4_ifdef(`__POSTGRESQL__', `', `prefunc=MADLIB_SCHEMA.fit_merge,')
     FINALFUNC=MADLIB_SCHEMA.fit_final
 );
-
 DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step(
     BYTEA[],
     BYTEA[],
@@ -1844,7 +1843,7 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step(
     INTEGER,
     INTEGER[],
     INTEGER,
-    INTEGER,
+    INTEGER[],
     INTEGER[],
     INTEGER[],
     BYTEA,
@@ -1860,9 +1859,9 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
     /* dist_key */               INTEGER,
     /* dist_key_mapping */       INTEGER[],
     /* current_seg_id */         INTEGER,
-    /* segments_per_host */      INTEGER,
+    /* segments_per_host */      INTEGER[],
     /* images_per_seg */         INTEGER[],
-    /* segments_per_host  */     INTEGER[],
+    /* accessible_gpus_for_seg */INTEGER[],
     /* prev_serialized_weights */BYTEA,
     /* custom_loss_cfunction */  BYTEA
 )(
@@ -1963,7 +1962,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
     seg_ids            INTEGER[],
     images_per_seg     INTEGER[],
     gpus_per_host      INTEGER,
-    segments_per_host  INTEGER
+    segments_per_host  INTEGER[]
 ) RETURNS DOUBLE PRECISION[] AS $$
     PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras_predict')
     return madlib_keras_predict.internal_keras_predict_wide(**globals())
@@ -2078,7 +2077,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_eval_transition(
     dist_key                           INTEGER,
     dist_key_mapping                   INTEGER[],
     current_seg_id                     INTEGER,
-    segments_per_host                  INTEGER,
+    segments_per_host                  INTEGER[],
     images_per_seg                     INTEGER[],
     accessible_gpus_for_seg            INTEGER[],
     should_clear_session               BOOLEAN,
@@ -2117,9 +2116,8 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.internal_keras_evaluate(
                                        INTEGER,
                                        INTEGER[],
                                        INTEGER,
-                                       INTEGER,
                                        INTEGER[],
-                                       BOOLEAN,
+                                       INTEGER[],
                                        INTEGER[],
                                        BOOLEAN,
                                        BYTEA);
@@ -2135,7 +2133,7 @@ CREATE AGGREGATE MADLIB_SCHEMA.internal_keras_evaluate(
     /* dist_key */                  INTEGER,
     /* dist_key_mapping */          INTEGER[],
     /* current_seg_id */            INTEGER,
-    /* segments_per_host */         INTEGER,
+    /* segments_per_host */         INTEGER[],
     /* images_per_seg*/             INTEGER[],
     /* accessible_gpus_for_seg */   INTEGER[],
     /* should_clear_session */      BOOLEAN,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 5decb4c..441c155 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -46,8 +46,6 @@ from utilities.utilities import rotate
 from utilities.utilities import madlib_version
 from utilities.utilities import is_platform_pg
 from utilities.utilities import get_seg_number
-from utilities.utilities import get_segments_per_host
-from utilities.utilities import rename_table
 import utilities.debug as DEBUG
 from utilities.debug import plpy_prepare
 from utilities.debug import plpy_execute
@@ -172,7 +170,6 @@ class FitMultipleModel(object):
         self.columns_dict['val_ind_shape_cols'] = self.val_ind_shape_cols
 
         self.use_gpus = use_gpus if use_gpus else False
-        self.segments_per_host = get_segments_per_host()
         self.model_input_tbl = unique_string('model_input')
         self.model_output_tbl = unique_string('model_output')
         self.schedule_tbl = unique_string('schedule')
@@ -182,6 +179,7 @@ class FitMultipleModel(object):
         self.rotate_schedule_tbl_plan = self.add_object_maps_plan = None
         self.hop_plan = self.udf_plan = None
 
+        self.segments_per_host = get_data_distribution_per_segment(source_table)
         if self.use_gpus:
             self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
                 self.schema_madlib, self.segments_per_host, self.module_name)
@@ -190,7 +188,7 @@ class FitMultipleModel(object):
 
         self.original_model_output_tbl = model_output_table
         if not self.original_model_output_tbl:
-	       plpy.error("Must specify an output table.")
+            plpy.error("Must specify an output table.")
 
         self.model_info_tbl = add_postfix(
             self.original_model_output_tbl, '_info')
@@ -375,12 +373,13 @@ class FitMultipleModel(object):
                     None,
                     self.use_gpus,
                     self.accessible_gpus_for_seg,
-                    seg_ids,
-                    images_per_seg,
-                    [], [], True,
-                    mst[self.object_map_col],
-                    self.model_output_tbl,
-                    mst[self.mst_key_col],
+                    self.segments_per_host,
+                    seg_ids,
+                    images_per_seg,
+                    [], [], True,
+                    mst[self.object_map_col],
+                    self.model_output_tbl,
+                    mst[self.mst_key_col],
                     is_train)
             total_eval_compute_time += eval_compute_time
             mst_metric_eval_time[mst[self.mst_key_col]] \
@@ -962,7 +961,7 @@ class FitMultipleModel(object):
                             src.{self.dist_key_col},
                             ARRAY{self.dist_key_mapping},
                             src.{self.gp_segment_id_col},
-                            {self.segments_per_host},
+                            ARRAY{self.segments_per_host},
                             ARRAY{self.images_per_seg_train},
                             ARRAY{self.accessible_gpus_for_seg},
                             model_in.{self.model_weights_col}::BYTEA,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 3f478eb..42fd7d9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1518,7 +1518,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
     dist_key                   INTEGER,
     dist_key_mapping           INTEGER[],
     current_seg_id             INTEGER,
-    segments_per_host          INTEGER,
+    segments_per_host          INTEGER[],
     images_per_seg             INTEGER[],
     accessible_gpus_for_seg    INTEGER[],
     serialized_weights         BYTEA,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index 735f1b2..243acd1 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -20,8 +20,9 @@
 import numpy as np
 from model_arch_info import ModelArchSchema
 from utilities.utilities import add_postfix
-from utilities.utilities import unique_string
+from utilities.utilities import get_seg_number
 from utilities.utilities import is_platform_pg
+from utilities.utilities import unique_string
 from utilities.validate_args import table_exists
 from madlib_keras_gpu_info import GPUInfoFunctions
 import plpy
@@ -303,7 +304,7 @@ def get_accessible_gpus_for_seg(schema_madlib, segments_per_host, module_name):
         for i in seg_query_result:
             if i['hostname'] in host_dict.keys():
                 accessible_gpus_for_seg[i['segment_id']] = host_dict[i['hostname']]
-            if 0 < accessible_gpus_for_seg[i['segment_id']] < segments_per_host and warning_flag:
+            if 0 < accessible_gpus_for_seg[i['segment_id']] < segments_per_host[i['segment_id']] and warning_flag:
                 plpy.warning(
                     'The number of GPUs per segment host is less than the number of '
                     'segments per segment host. When different segments share the '
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index d6b362d..053a5f9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -26,7 +26,6 @@ from predict_input_params import PredictParamsProcessor
 from utilities.control import MinWarning
 from utilities.utilities import _assert
 from utilities.utilities import add_postfix
-from utilities.utilities import get_segments_per_host
 from utilities.utilities import unique_string
 from utilities.utilities import get_psql_type
 from utilities.utilities import split_quoted_delimited_str
@@ -55,10 +54,12 @@ class BasePredict():
         self.pred_type = pred_type
         self.module_name = module_name
 
-        self.segments_per_host = get_segments_per_host()
         self.use_gpus = use_gpus if use_gpus else False
+        self.segments_per_host = get_data_distribution_per_segment(test_table)
         if self.use_gpus:
-            accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib, self.segments_per_host, self.module_name)
+            accessible_gpus_for_seg = get_accessible_gpus_for_seg(schema_madlib,
+                                                                  self.segments_per_host,
+                                                                  self.module_name)
             _assert(len(set(accessible_gpus_for_seg)) == 1,
                 '{0}: Asymmetric gpu configurations are not supported'.format(self.module_name))
             self.gpus_per_host = accessible_gpus_for_seg[0]
@@ -83,7 +84,6 @@ class BasePredict():
         gp_segment_id_col, seg_ids_test, \
         images_per_seg_test = get_image_count_per_seg_for_non_minibatched_data_from_db(
             self.test_table)
-        segments_per_host = get_segments_per_host()
 
         if self.pred_type == 1:
             rank_create_sql = ""
@@ -177,7 +177,7 @@ class BasePredict():
                                 ARRAY{seg_ids_test},
                                 ARRAY{images_per_seg_test},
                                 {self.gpus_per_host},
-                                {segments_per_host})) AS prob
+                                ARRAY{self.segments_per_host})) AS prob
 
                             FROM {self.test_table}
                             LEFT JOIN
@@ -364,7 +364,7 @@ def internal_keras_predict(independent_var, model_architecture, model_weights,
     try:
         device_name = get_device_name_and_set_cuda_env(gpus_per_host, current_seg_id)
         if model_key not in SD:
-            set_keras_session(device_name, gpus_per_host, segments_per_host)
+            set_keras_session(device_name, gpus_per_host, segments_per_host[current_seg_id])
             model = model_from_json(model_architecture)
             set_model_weights(model, model_weights)
             SD[model_key] = model
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index 72365de..988d1f3 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -425,6 +425,17 @@ SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
     3);$TRAP$) = 1,
     'Object table not specified for custom function in compile_params.');
 
+--- Test fit with a table that is not distributed to all 3 segments
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'iris_data_2seg_packed',
+    'keras_saved_out',
+    'iris_model_arch',
+    2,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    2);
+
 -- Test GD is cleared
 -- Setup
 CREATE OR REPLACE FUNCTION get_gd_keys_len()
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
index b6ce525..9d113b1 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
@@ -99,7 +99,7 @@ CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_mult
     dist_key                    INTEGER,
     dist_key_mapping            INTEGER[],
     current_seg_id              INTEGER,
-    segments_per_host           INTEGER,
+    segments_per_host           INTEGER[],
     images_per_seg              INTEGER[],
     accessible_gpus_for_seg     INTEGER[],
     serialized_weights          BYTEA,
@@ -146,7 +146,7 @@ $$ LANGUAGE plpythonu VOLATILE;
 
 CREATE OR REPLACE FUNCTION validate_transition_function_params(
     current_seg_id                       INTEGER,
-    segments_per_host                    INTEGER,
+    segments_per_host                    INTEGER[],
     images_per_seg                       INTEGER[],
     expected_num_calls                   INTEGER,
     expected_dist_key                    INTEGER,
@@ -270,7 +270,8 @@ $$
     fit_mult.images_per_seg_train = images_per_seg
     fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys
     fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
-    fit_mult.segments_per_host = num_data_segs
+    data_distribution_per_seg = [num_data_segs] * num_dist_keys
+    fit_mult.segments_per_host = data_distribution_per_seg
 
     fit_mult.msts_for_schedule = fit_mult.msts[:num_models]
     if num_models < num_dist_keys:
@@ -296,14 +297,14 @@ $$
                 ORDER BY __dist_key__  -- This would be gp_segment_id if it weren't a simulation
             ) AS expected_dist_key_mapping,
             ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
-            {num_data_segs} AS segments_per_host,
+            ARRAY{data_distribution_per_seg} AS segments_per_host,
             __dist_key__
         FROM {fm.source_table}
         GROUP BY __dist_key__
         DISTRIBUTED BY (__dist_key__);
     """.format(
             fm=fit_mult,
-            num_data_segs=num_data_segs,
+            data_distribution_per_seg=data_distribution_per_seg,
             exp_table=expected_distkey_mappings_tbl
         )
     plpy.execute(create_distkey_map_tbl_cmd)
@@ -427,7 +428,7 @@ SELECT test_run_training('src_3segs', 0, False, False, False);
     CREATE TABLE validate_params_results AS
         SELECT validate_transition_function_params(
             s.gp_segment_id,
-            3,
+            ARRAY[3, 3, 3],
             s.expected_images_per_seg,
             5,                 -- expected num_calls (per dist_key)
             s.__dist_key__,
@@ -773,7 +774,7 @@ DROP FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
     dist_key                    INTEGER,
     dist_key_mapping            INTEGER[],
     current_seg_id              INTEGER,
-    segments_per_host           INTEGER,
+    segments_per_host           INTEGER[],
     images_per_seg              INTEGER[],
     accessible_gpus_for_seg     INTEGER[],
     serialized_weights          BYTEA,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
index 67b1aa9..7f68268 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
@@ -302,6 +302,24 @@ SELECT training_preprocessor_dl('iris_test',         -- Source table
                                 2                    -- buffer_size  (15 buffers)
                                 );
 
+
+-- Assuming that there are only 3 segments, we want to distribute the data to < 3 segs
+DROP TABLE IF EXISTS segments_to_use;
+CREATE TABLE segments_to_use (dbid INTEGER, notes TEXT);
+INSERT INTO segments_to_use VALUES (2, 'GPU segment');
+INSERT INTO segments_to_use VALUES (3, 'GPU segment');
+
+DROP TABLE IF EXISTS iris_data_2seg_packed, iris_data_2seg_packed_summary;
+SELECT training_preprocessor_dl('iris_data',         -- Source table
+                                'iris_data_2seg_packed',  -- Output table
+                                'class_text',        -- Dependent variable
+                                'attributes',        -- Independent variable
+                                NULL,
+                                NULL,
+                                NULL,
+                                'segments_to_use'
+                                );
+
 -- Create multi io dataset
 
 DROP TABLE IF EXISTS iris_mult;
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 31a61a8..bb40fba 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -88,6 +88,7 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
         # We test on segment 0, which has 3 buffers filled with 10 identical
         #  images each, or 30 images total
         self.total_images_per_seg = [3*len(self.dependent_var_int),20,40]
+        self.data_segments_per_host = [4]
 
         self.dummy_prev_weights = 'dummy weights'
 
@@ -108,8 +109,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             None, [self.dependent_var] , [self.independent_var],
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_state.tostring(),  **kwargs)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            previous_state.tostring(), **kwargs)
 
         image_count = kwargs['GD']['agg_image_count']
         self.assertEqual(ending_image_count, image_count)
@@ -123,9 +125,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             None, [self.dependent_var] , [self.independent_var],
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.serialized_weights,
-             True, **kwargs)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            self.serialized_weights, True, **kwargs)
 
         self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
         image_count = kwargs['GD']['agg_image_count']
@@ -139,8 +141,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             [self.dependent_var] , [self.independent_var],
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.serialized_weights, True, **k)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            self.serialized_weights, True, **k)
 
         self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
         image_count = k['GD']['agg_image_count']
@@ -160,8 +163,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             state, [self.dependent_var] , [self.independent_var],
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, **kwargs)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            self.dummy_prev_weights, **kwargs)
 
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
@@ -177,9 +181,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             None, [self.dependent_var] , [self.independent_var],
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True,
-            **kwargs)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            self.dummy_prev_weights, True, True, **kwargs)
 
         self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
         image_count = kwargs['GD']['agg_image_count']
@@ -203,8 +207,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             [self.dependent_var] , [self.independent_var],
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.serialized_weights, True, **k)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            self.serialized_weights, True, **k)
 
         self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
         image_count = k['GD']['agg_image_count']
@@ -226,9 +231,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             state, [self.dependent_var] , [self.independent_var],
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_state.tostring(),
-            **kwargs)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            previous_state.tostring(), **kwargs)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -249,7 +254,7 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(),
             self.serialized_weights, self.compile_params, 0,
-            self.dist_key_mapping, 0, 4,
+            self.dist_key_mapping, 0, self.data_segments_per_host,
             self.total_images_per_seg, self.accessible_gpus_for_seg,
             last_iteration, None, **kwargs)
 
@@ -275,7 +280,7 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(),
             'dummy_model_weights', None, 0,
-            self.dist_key_mapping, 0, 4,
+            self.dist_key_mapping, 0, self.data_segments_per_host,
             self.total_images_per_seg, self.accessible_gpus_for_seg,
             last_iteration, **kwargs)
         agg_loss, agg_accuracy, image_count = new_state
@@ -299,7 +304,7 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(),
             'dummy_model_weights', None, 0,
-            self.dist_key_mapping, 0, 4,
+            self.dist_key_mapping, 0, self.data_segments_per_host,
             self.total_images_per_seg, self.accessible_gpus_for_seg,
             last_iteration, **kwargs)
 
@@ -317,9 +322,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             None, [self.dependent_var] , [self.independent_var],
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights,
-            True, **kwargs)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            self.dummy_prev_weights, True, **kwargs)
 
         state = np.fromstring(new_state, dtype=np.float32)
 
@@ -349,8 +354,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             [self.dependent_var] , [self.independent_var],
             [self.dependent_var_shape], [self.independent_var_shape],
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.serialized_weights, False, **k)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            self.serialized_weights, False, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
         state = np.fromstring(new_state, dtype=np.float32)
@@ -389,8 +395,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             [None], [None],
             [None], [None],
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.serialized_weights, False, **k)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            self.serialized_weights, False, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
         weights = np.fromstring(new_state, dtype=np.float32)
@@ -423,8 +430,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
             [None], [None],
             [None], [None],
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.serialized_weights, True, **k)
+            self.dist_key_mapping, 0, self.data_segments_per_host,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
+            self.serialized_weights, True, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
 
@@ -759,6 +767,7 @@ class InternalKerasPredictTestCase(unittest.TestCase):
 
         self.independent_var = [[[[240]]]]
         self.total_images_per_seg = [3,3,4]
+        self.data_segments_per_host = [4]
 
     def tearDown(self):
         self.module_patcher.stop()
@@ -772,7 +781,7 @@ class InternalKerasPredictTestCase(unittest.TestCase):
         result = self.subject.internal_keras_predict(
             self.independent_var, self.model.to_json(),
             serialized_weights, 255, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, **k)
+            self.total_images_per_seg, 0, self.data_segments_per_host, **k)
         self.assertEqual(2, len(result))
         self.assertEqual(1,  k['SD']['row_count'])
         self.assertEqual(True, 'segment_model_predict' in k['SD'])
@@ -784,7 +793,8 @@ class InternalKerasPredictTestCase(unittest.TestCase):
         k['SD']['segment_model_predict'] = self.model
         result = self.subject.internal_keras_predict(
             self.independent_var, None, None, 255, 0,
-            self.all_seg_ids, self.total_images_per_seg, 0, 4, **k)
+            self.all_seg_ids, self.total_images_per_seg, 0,
+            self.data_segments_per_host, **k)
         self.assertEqual(2, len(result))
         self.assertEqual(2,  k['SD']['row_count'])
         self.assertEqual(True, 'segment_model_predict' in k['SD'])
@@ -797,7 +807,8 @@ class InternalKerasPredictTestCase(unittest.TestCase):
         k['SD']['segment_model_predict'] = self.model
         result = self.subject.internal_keras_predict(
             self.independent_var, None, None, 255, 0,
-            self.all_seg_ids, self.total_images_per_seg, 0, 4, **k)
+            self.all_seg_ids, self.total_images_per_seg, 0,
+            self.data_segments_per_host, **k)
 
         # we expect len(result) to be 3 because we have 3 dense layers in the
         # architecture
@@ -818,7 +829,7 @@ class InternalKerasPredictTestCase(unittest.TestCase):
             self.subject.internal_keras_predict(
                 self.independent_var, self.model.to_json(), serialized_weights,
                 255, current_seg_id, self.all_seg_ids,
-                self.total_images_per_seg, 0, 4, **k)
+                self.total_images_per_seg, 0, self.data_segments_per_host, **k)
         self.assertEqual("ValueError('-1 is not in list',)", str(error.exception))
         self.assertEqual(False, 'row_count' in k['SD'])
         self.assertEqual(False, 'segment_model_predict' in k['SD'])
@@ -958,6 +969,11 @@ class MadlibKerasWrapperTestCase(unittest.TestCase):
         result = self.subject.get_gpu_memory_fraction(gpus_per_host, segments_per_host)
         self.assertEqual(result, 0.225)
 
+        gpus_per_host = 3
+        segments_per_host = 4
+        result = self.subject.get_gpu_memory_fraction(gpus_per_host, segments_per_host)
+        self.assertEqual(result, 0.45)
+
     def test_get_device_name_and_set_cuda_env_postgres(self):
         self.subject.is_platform_pg = Mock(return_value = True)
 
@@ -1255,10 +1271,14 @@ class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
         self.plpy_mock_execute = MagicMock()
         plpy.execute = self.plpy_mock_execute
 
+        self.plpy_mock_warning = MagicMock()
+        plpy.warning = self.plpy_mock_warning
+
         self.module_patcher = patch.dict('sys.modules', patches)
         self.module_patcher.start()
         import madlib_keras_validator
         self.subject = madlib_keras_validator
+        self.subject.FitCommonValidator._validate_common_args = Mock()
         self.dep_shape_cols = [[10,1,1,1]]
         self.ind_shape_cols = [[10,2]]
 
@@ -1267,7 +1287,6 @@ class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
 
 
     def test_is_valid_metrics_compute_frequency_True_None(self):
-        self.subject.FitCommonValidator._validate_common_args = Mock()
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', self.dep_shape_cols,
@@ -1276,7 +1295,6 @@ class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_True_num(self):
-        self.subject.FitCommonValidator._validate_common_args = Mock()
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', self.dep_shape_cols,
@@ -1285,7 +1303,6 @@ class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_zero(self):
-        self.subject.FitCommonValidator._validate_common_args = Mock()
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', self.dep_shape_cols,
@@ -1294,7 +1311,6 @@ class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_greater(self):
-        self.subject.FitCommonValidator._validate_common_args = Mock()
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', self.dep_shape_cols,
@@ -1558,6 +1574,8 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
     def test_get_gpus_per_one_seg_gpu_gpdb(self):
 
         self.subject.is_platform_pg = Mock(return_value = False)
+        self.plpy_mock_warning = MagicMock()
+        plpy.warning = self.plpy_mock_warning
 
         self.plpy_mock_execute.side_effect = \
             [ [],
@@ -1569,11 +1587,14 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
             ]]
 
         self.assertEqual([1,0,0], self.subject.get_accessible_gpus_for_seg(
-            'schema_madlib', 2, 'foo'))
+            'schema_madlib', [1,1,1], 'foo'))
+        self.assertEqual(0, self.plpy_mock_warning.call_count)
 
     def test_get_gpus_per_mult_seg_gpu_gpdb(self):
 
         self.subject.is_platform_pg = Mock(return_value = False)
+        self.plpy_mock_warning = MagicMock()
+        plpy.warning = self.plpy_mock_warning
 
         self.plpy_mock_execute.side_effect = \
             [[],
@@ -1586,7 +1607,8 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
             ]]
 
         self.assertEqual([1,1,0,0], self.subject.get_accessible_gpus_for_seg(
-            'schema_madlib', 2, 'foo'))
+            'schema_madlib', [2,2,2,2], 'foo'))
+        self.assertEqual(1, self.plpy_mock_warning.call_count)
 
     def test_get_gpus_per_no_gpu_gpdb(self):