You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by do...@apache.org on 2020/09/29 18:08:02 UTC

[madlib] 01/04: DL: Implement caching for fit_multiple_model

This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit a119fe882337f68e4105f5cd24179b4d87121e00
Author: Ekta Khanna <ek...@vmware.com>
AuthorDate: Thu Sep 10 12:31:28 2020 -0700

    DL: Implement caching for fit_multiple_model
    
    Currently, passing the independent and dependent vars to the
    transition function on every hop is what takes up most of the time.
    As part of this commit, add a new fit_multiple_transition function that
    reads all the rows (for each seg) into the cache (SD) on the very first
    hop; for each subsequent hop/iteration, the data is read from the
    cache instead of the table, and the cache is cleared at the final
    training call. This reduces the time spent passing the data to the
    transition function.
    Since the data is cached in memory, the memory usage per segment
    increases significantly. Because of this, caching is opt-in: a new
    optional param `use_caching` (default FALSE) is added to
    madlib_keras_fit_multiple_model(). It should be set to TRUE only if
    the memory available on each host is at least the total_memory
    given by the following calculation:
    
       IND_SZ (indep var size of each row) = ((image_dimension)*4)*(# of images per buffer)
       DEP_SZ (dep var size of each row) = (#DEP_VAR * 4)*(# of images per buffer)
       memory_data = (#seg_per_host) * (#rows_per_seg * IND_SZ) + (#seg_per_host) * (#rows_per_seg * DEP_SZ)
       memory_model = model_size * #models_per_seg * #seg_per_host
       total_memory = memory_data + memory_model
---
 .../modules/deep_learning/madlib_keras.py_in       |  89 ++++++++-
 .../madlib_keras_fit_multiple_model.py_in          |  58 ++++--
 .../madlib_keras_fit_multiple_model.sql_in         |  17 +-
 .../test/madlib_keras_model_selection.sql_in       | 200 +++++++++++++++++++-
 .../test/unit_tests/test_madlib_keras.py_in        | 202 ++++++++++++++++++++-
 5 files changed, 545 insertions(+), 21 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index e8eac71..00889b9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -523,7 +523,7 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
-                                  agg_image_count, total_images)
+                                       agg_image_count, total_images)
     if is_last_row:
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
@@ -531,6 +531,93 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
 
     return return_state
 
+def fit_multiple_transition(state, dependent_var, independent_var, dependent_var_shape,
+                             independent_var_shape, model_architecture,
+                             compile_params, fit_params, dist_key, dist_key_mapping,
+                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
+                             accessible_gpus_for_seg, prev_serialized_weights,
+                             is_final_training_call, custom_function_map=None, **kwargs):
+    """
+    Transition function used by madlib_keras_fit_multiple_model() when
+    use_caching is TRUE. On the very first hop, each buffer's x/y arrays
+    are cached in SD; on all later hops dependent_var and independent_var
+    are passed in as None (shapes as [0]) and the cached arrays are used.
+    Some things to note in this function are:
+    - prev_serialized_weights can be passed in as None for the very first
+      hop and the final training call; the cache is still populated, but
+      no model is trained and the image count is returned
+    - x_train, y_train and cache_set are cleared from SD when
+      is_final_training_call is TRUE
+    """
+    if not state:
+        agg_image_count = 0
+    else:
+        agg_image_count = float(state)
+
+    SD = kwargs['SD']
+    is_cache_set = 'cache_set' in SD
+
+    # Prepare the data
+    if is_cache_set:
+        if 'x_train' not in SD or 'y_train' not in SD:
+            plpy.error("cache not populated properly.")
+        total_images = None
+        is_last_row = True
+    else:
+        if not independent_var or not dependent_var:
+            return state
+        if 'x_train' not in SD:
+            SD['x_train'] = list()
+            SD['y_train'] = list()
+        agg_image_count += independent_var_shape[0]
+        total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
+                                                          images_per_seg)
+        is_last_row = agg_image_count == total_images
+        if is_last_row:
+            SD['cache_set'] = True
+        x_train_current = np_array_float32(independent_var, independent_var_shape)
+        y_train_current = np_array_int16(dependent_var, dependent_var_shape)
+        SD['x_train'].append(x_train_current)
+        SD['y_train'].append(y_train_current)
+
+    # Passed in weights can be None: the cache is still populated above, but no model is set, so return early.
+    # NOTE(review): cache_set is only set on a segment's last row (above); for earlier rows del SD['cache_set'] raises KeyError and the partial cache is dropped -- consider SD.pop('cache_set', None) / cleaning only on the last row.
+    if prev_serialized_weights is None:
+        if is_final_training_call:
+            del SD['x_train']
+            del SD['y_train']
+            del SD['cache_set']
+        return float(agg_image_count)
+
+    segment_model = None
+    if is_last_row:
+        device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
+        segment_model, sess = get_init_model_and_sess(SD, device_name,
+                                                      accessible_gpus_for_seg[current_seg_id],
+                                                      segments_per_host,
+                                                      model_architecture, compile_params,
+                                                      custom_function_map)
+        set_model_weights(segment_model, prev_serialized_weights)
+
+        fit_params = parse_and_validate_fit_params(fit_params)
+        for i in range(len(SD['x_train'])):
+            # Fit segment model on data
+            segment_model.fit(SD['x_train'][i], SD['y_train'][i], **fit_params)
+
+
+    return_state = get_state_to_return(segment_model, is_last_row, True,
+                                       agg_image_count, total_images)
+
+    if is_last_row:
+        SD_STORE.clear_SD(SD)
+        clear_keras_session(sess)
+        if is_final_training_call:
+            del SD['x_train']
+            del SD['y_train']
+            del SD['cache_set']
+
+    return return_state
+
 def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image_count,
                         total_images):
     """
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index b847550..c821474 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -81,7 +81,7 @@ class FitMultipleModel():
                  model_selection_table, num_iterations,
                  use_gpus=False, validation_table=None,
                  metrics_compute_frequency=None, warm_start=False, name="",
-                 description="", **kwargs):
+                 description="", use_caching=False, **kwargs):
         # set the random seed for visit order/scheduling
         random.seed(1)
         if is_platform_pg():
@@ -97,6 +97,7 @@ class FitMultipleModel():
         self.metrics_compute_frequency = metrics_compute_frequency
         self.name = name
         self.description = description
+        self.use_caching = use_caching if use_caching is not None else False
         self.module_name = 'madlib_keras_fit_multiple_model'
         self.schema_madlib = schema_madlib
         self.version = madlib_version(self.schema_madlib)
@@ -115,6 +116,7 @@ class FitMultipleModel():
         self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
         self.use_gpus = use_gpus
         self.segments_per_host = get_segments_per_host()
+        self.cached_source_table = unique_string('cached_source_table')
         if self.use_gpus:
             self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
                 self.schema_madlib, self.segments_per_host, self.module_name)
@@ -233,7 +235,7 @@ class FitMultipleModel():
                 self.is_final_training_call = (iter == self.num_iterations and mst_idx == total_msts-1)
                 if mst_idx == 0:
                     start_iteration = time.time()
-                self.run_training(mst_idx)
+                self.run_training(mst_idx, mst_idx==0 and iter==1)
                 if mst_idx == (total_msts - 1):
                     end_iteration = time.time()
                     self.info_str = "\tTime for training in iteration " \
@@ -249,6 +251,7 @@ class FitMultipleModel():
                 if self.validation_table:
                     self.evaluate_model(iter, self.validation_table, False)
             plpy.info("\n"+self.info_str)
+        plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table};".format(self=self))
 
     def evaluate_model(self, epoch, table, is_train):
         if is_train:
@@ -594,7 +597,7 @@ class FitMultipleModel():
             if self.validation_table:
                 self.update_info_table(mst, False)
 
-    def run_training(self, mst_idx):
+    def run_training(self, mst_idx, is_very_first_hop):
         # NOTE: In the DL module, we want to avoid CREATING TEMP tables
         # (creates a slice which stays until the session is disconnected)
         # or minimize writing queries that generate plans with Motions (creating
@@ -622,12 +625,39 @@ class FitMultipleModel():
                    **locals())
         plpy.execute(mst_weights_query)
         use_gpus = self.use_gpus if self.use_gpus else False
+        dep_shape_col = self.dep_shape_col
+        ind_shape_col = self.ind_shape_col
+        dep_var = mb_dep_var_col
+        indep_var = mb_indep_var_col
+        source_table = self.source_table
+        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+        if self.use_caching:
+            # Caching populates the independent_var and dependent_var into the cache on the very first hop
+            # For the very_first_hop, we want to run the transition function on all segments, including
+            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
+            # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
+            # table to join for referencing the dist_key
+            if is_very_first_hop:
+                plpy.execute("""
+                    DROP TABLE IF EXISTS {self.cached_source_table};
+                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
+                    """.format(self=self, dist_key_col=dist_key_col))
+            else:
+                dep_shape_col = 'ARRAY[0]'
+                ind_shape_col = 'ARRAY[0]'
+                dep_var = 'NULL'
+                indep_var = 'NULL'
+                source_table = self.cached_source_table
+            if is_very_first_hop or self.is_final_training_call:
+                where_clause = ""
+
         uda_query = """
             CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
             SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
                 {mb_indep_var_col},
-                {self.dep_shape_col},
-                {self.ind_shape_col},
+                {dep_shape_col},
+                {ind_shape_col},
                 {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
                 {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
                 {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
@@ -639,21 +669,27 @@ class FitMultipleModel():
                 {use_gpus}::BOOLEAN,
                 ARRAY{self.accessible_gpus_for_seg},
                 {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_iteration}::BOOLEAN,
+                {is_final_training_call}::BOOLEAN,
+                {use_caching}::BOOLEAN,
                 {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
                 )::BYTEA AS {self.model_weights_col},
                 {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
                 ,src.{dist_key_col} AS {dist_key_col}
-            FROM {self.source_table} src JOIN {self.mst_weights_tbl}
+            FROM {source_table} src JOIN {self.mst_weights_tbl}
                 USING ({dist_key_col})
-            WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL
+            {where_clause}
             GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
             DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=mb_dep_var_col,
-                       mb_indep_var_col=mb_indep_var_col,
-                       is_final_iteration=True,
+            """.format(mb_dep_var_col=dep_var,
+                       mb_indep_var_col=indep_var,
+                       dep_shape_col=dep_shape_col,
+                       ind_shape_col=ind_shape_col,
+                       is_final_training_call=self.is_final_training_call,
+                       use_caching=self.use_caching,
                        dist_key_col=dist_key_col,
                        use_gpus=use_gpus,
+                       source_table=source_table,
+                       where_clause=where_clause,
                        self=self
                        )
         plpy.execute(uda_query)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 392a3be..5a50733 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1416,7 +1416,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
     name                    VARCHAR,
-    description             VARCHAR
+    description             VARCHAR,
+    use_caching             BOOLEAN DEFAULT FALSE
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
     from utilities.control import SetGUC
@@ -1506,13 +1507,17 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
     segments_per_host          INTEGER,
     images_per_seg             INTEGER[],
     use_gpus                   BOOLEAN,
-    accessible_gpus_for_seg               INTEGER[],
+    accessible_gpus_for_seg    INTEGER[],
     prev_serialized_weights    BYTEA,
-    is_final_iteration         BOOLEAN,
+    is_final_training_call     BOOLEAN,
+    use_caching                BOOLEAN,
     custom_function_map        BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(is_multiple_model = True, **globals())
+    if use_caching:
+        return madlib_keras.fit_multiple_transition(**globals())
+    else:
+        return madlib_keras.fit_transition(is_final_iteration = True, is_multiple_model = True, **globals())
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1533,6 +1538,7 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step_multiple_model(
     INTEGER[],
     BYTEA,
     BOOLEAN,
+    BOOLEAN,
     BYTEA);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
     /* dependent_var */              BYTEA,
@@ -1550,7 +1556,8 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
     /* use_gpus */                   BOOLEAN,
     /* accessible_gpus_for_seg */    INTEGER[],
     /* prev_serialized_weights */    BYTEA,
-    /* is_final_iteration */         BOOLEAN,
+    /* is_final_training_call */     BOOLEAN,
+    /* use_caching */                BOOLEAN,
     /* custom_function_obj_map */    BYTEA
 )(
     STYPE=BYTEA,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 82b2647..31bd9d6 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -375,6 +375,39 @@ SELECT assert(
         'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
+-- Testing with caching. NOTE(review): "num_classes = NULL" and "class_values = NULL" in the assert below evaluate to NULL rather than TRUE; "IS NULL" was probably intended -- confirm how assert() handles a NULL condition.
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_one_hot_encoded_packed',
+	'iris_multiple_model',
+	'mst_table_4row',
+	3,
+	FALSE, NULL, NULL, NULL, NULL, NULL,
+	TRUE
+);
+
+SELECT assert(
+        model_arch_table = 'iris_model_arch' AND
+        validation_table is NULL AND
+        model_info = 'iris_multiple_model_info' AND
+        source_table = 'iris_data_one_hot_encoded_packed' AND
+        model = 'iris_multiple_model' AND
+        model_selection_table = 'mst_table_4row' AND
+        object_table IS NULL AND
+        dependent_varname = 'class_one_hot_encoded' AND
+        independent_varname = 'attributes' AND
+        madlib_version is NOT NULL AND
+        num_iterations = 3 AND
+        start_training_time < now() AND
+        end_training_time < now() AND
+        dependent_vartype = 'integer[]' AND
+        num_classes = NULL AND
+        class_values = NULL AND
+        normalizing_const = 1 AND
+        metrics_iters = ARRAY[3],
+        'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
 -- Test the output table created are all persistent(not unlogged)
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'Model output table is unlogged');
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = false, 'Model summary output table is unlogged');
@@ -495,6 +528,85 @@ SELECT assert(
 FROM iris_multiple_model_info
 WHERE compile_params like '%lr=0.001%';
 
+-- Testing with caching
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table',
+	6,
+	FALSE,
+	'iris_data_one_hot_encoded_packed', NULL, NULL, NULL, NULL,
+	TRUE
+);
+
+SELECT assert(
+        source_table = 'iris_data_packed' AND
+        validation_table = 'iris_data_one_hot_encoded_packed' AND
+        model = 'iris_multiple_model' AND
+        model_info = 'iris_multiple_model_info' AND
+        dependent_varname = 'class_text' AND
+        independent_varname = 'attributes' AND
+        model_arch_table = 'iris_model_arch' AND
+        num_iterations = 6 AND
+        start_training_time < now() AND
+        end_training_time < now() AND
+        madlib_version is NOT NULL AND
+        num_classes = 3 AND
+        class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+        dependent_vartype LIKE '%char%' AND
+        normalizing_const = 1 AND
+        name IS NULL AND
+        description IS NULL AND
+        metrics_compute_frequency = 6,
+        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
+SELECT assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        fit_params = $MAD$batch_size=50, epochs=1$MAD$::text AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 1 AND
+        array_upper(training_loss, 1) = 1 AND
+        validation_metrics_final >= 0  AND
+        validation_loss_final  >= 0  AND
+        array_upper(validation_metrics, 1) = 1 AND
+        array_upper(validation_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
+
+SELECT assert(
+  training_loss[6]-training_loss[1] < 0.1 AND
+  training_metrics[6]-training_metrics[1] > -0.1,
+    'The loss and accuracy should have improved with more iterations.'
+)
+FROM iris_multiple_model_info
+WHERE compile_params like '%lr=0.001%';
+
 -- Test when number of configs(1) is less than number of segments(3)
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
 SELECT madlib_keras_fit_multiple_model(
@@ -543,6 +655,55 @@ SELECT assert(cnt = 1,
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
 
+-- Testing with caching configs(1) is less than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table_1row',
+	3,
+	FALSE,
+	NULL,
+	1,
+	FALSE,
+	'multi_model_name',
+	'multi_model_descr',
+	TRUE
+);
+
+SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 3 AND
+        array_upper(training_loss, 1) = 3 AND
+        array_upper(metrics_elapsed_time, 1) = 3,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
+        'Keras Fit Multiple invalid elapsed time calculation.')
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(
+        name = 'multi_model_name' AND
+        description = 'multi_model_descr' AND
+        metrics_compute_frequency = 1,
+        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+
 -- Test when number of configs(4) larger than number of segments(3)
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
 SELECT madlib_keras_fit_multiple_model(
@@ -580,6 +741,44 @@ FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
 AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
 
+-- Test with caching when number of configs(4) larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table_4row',
+	3,
+	FALSE, NULL, NULL, NULL, NULL, NULL,
+	TRUE
+);
+
+-- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
+-- but we change it to 'on' in fit_multiple.py. Assert that the value is
+-- reset after calling fit_multiple
+SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
+
+SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 1 AND
+        array_upper(training_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
+AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+
 -- Test when class values have NULL values
 UPDATE iris_data_packed_summary SET class_values = ARRAY['Iris-setosa','Iris-versicolor',NULL];
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
@@ -606,7 +805,6 @@ CREATE TABLE __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed as select
 CREATE TABLE __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed_summary as select * from iris_data_packed_summary;
 
 -- do not drop the output table created in the previous test
-SELECT count(*) from iris_multiple_model;
 SELECT madlib_keras_fit_multiple_model(
 	'__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed',
 	'__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model',
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 6dacdcd..9990adc 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -145,7 +145,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
-    def test_fit_transition_multiple_model_first_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_first_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -172,6 +172,36 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
+    def test_fit_transition_multiple_model_cache_first_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+        starting_image_count = 0
+        ending_image_count = len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+
+        k = {'SD': {}}
+
+        new_state = self.subject.fit_multiple_transition(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
+        image_count = new_state
+        self.assertEqual(ending_image_count, image_count)
+        # set_session should only be called for the last row
+        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # Clear session must not be called for the first buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue('cache_set' not in k['SD'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
     def _test_fit_transition_middle_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
@@ -228,7 +258,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # Clear session and sess.close must not get called for the middle buffer
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
-    def test_fit_transition_multiple_model_middle_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_middle_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -259,6 +289,41 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # Clear session and sess.close must not get called for the middle buffer
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
+    def test_fit_transition_multiple_model_cache_middle_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+
+        state = starting_image_count
+        new_state = self.subject.fit_multiple_transition(
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+        image_count = new_state
+        self.assertEqual(ending_image_count, image_count)
+        # set_session is only called for the last buffer
+        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must not get called for the middle buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue('cache_set' not in k['SD'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
     def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
@@ -327,7 +392,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         #  but not in postgres
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
-    def test_fit_transition_multiple_model_last_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_last_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -362,6 +427,137 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         #  but not in postgres
         self.assertEqual(1, self.subject.K.clear_session.call_count)
 
+    def test_fit_transition_multiple_model_cache_last_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+
+        state = starting_image_count
+        new_state = self.subject.fit_multiple_transition(
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        weights = np.rint(state[0:]).astype(np.int)
+
+        ## image count should not be added to the final state of
+        # fit multiple
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is only called for the last buffer
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must get called for the last buffer
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue(k['SD']['cache_set'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
+    def test_fit_transition_multiple_model_cache_filled_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+
+        new_state = self.subject.fit_multiple_transition(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        weights = np.rint(state[0:]).astype(np.int)
+
+        ## image count should not be added to the final state of
+        # fit multiple
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is only called for the last buffer
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must get called for the last buffer
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue(k['SD']['cache_set'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
+    def test_fit_transition_multiple_model_cache_filled_final_training_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+
+        new_state = self.subject.fit_multiple_transition(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        weights = np.rint(state[0:]).astype(np.int)
+
+        ## image count should not be added to the final state of
+        # fit multiple
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is only called for the last buffer
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must get called for the last buffer
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue('cache_set' not in k['SD'])
+        self.assertTrue('x_train' not in k['SD'])
+        self.assertTrue('y_train' not in k['SD'])
+
     def test_fit_transition_first_buffer_pass_pg(self):
         self._test_fit_transition_first_buffer_pass(True)