Posted to commits@madlib.apache.org by do...@apache.org on 2020/09/29 18:12:46 UTC

[madlib] branch master updated (3cb2305 -> 49262a5)

This is an automated email from the ASF dual-hosted git repository.

domino pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git.


 discard 3cb2305  add use_caching param descr and examples to user docs
 discard b20119d  Convert EXECUTE to PERFORM
 discard fe58ef5  Address review comments
 discard a119fe8  DL: Implement caching for fit_multiple_model
     new 49262a5  DL: Implement caching for fit_multiple_model

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user --force pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (3cb2305)
            \
             N -- N -- N   refs/heads/master (49262a5)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:


[madlib] 01/01: DL: Implement caching for fit_multiple_model


domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 49262a5bbf95be941f0455c921983dcb5b9632ef
Author: Ekta Khanna <ek...@vmware.com>
AuthorDate: Thu Sep 10 12:31:28 2020 -0700

    DL: Implement caching for fit_multiple_model
    
    Currently, passing the independent and dependent vars to the
    transition function is what takes up most of the time.
    This commit adds a new transition function,
    fit_multiple_transition_caching, that reads all the rows (for each
    segment) into the cache (SD) on the very first hop. For each
    subsequent hop/iteration, the data is read from the cache instead of
    the table, and the cache is cleared at the final training call. This
    reduces the time spent passing the data to the transition function.
    Since the data is cached in memory, the memory usage per segment
    increases significantly. For this reason caching is opt-in: a new
    optional param `use_caching` is added to
    madlib_keras_fit_multiple_model(), which can be set to TRUE if the
    memory available on each host (summed over its segments) meets the
    following calculation:
    
       IND_SZ (indep var size of each row) = (image_dimension * 4) * (#images per buffer)
       DEP_SZ (dep var size of each row)   = (#DEP_VAR * 4) * (#images per buffer)
       memory_data  = #seg_per_host * (#rows_per_seg * IND_SZ) + #seg_per_host * (#rows_per_seg * DEP_SZ)
       memory_model = model_size * #models_per_seg * #seg_per_host
       total_memory = memory_data + memory_model
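To make the sizing rule concrete, here is a small Python sketch that evaluates the formula above. It is only an illustration: the function and parameter names are hypothetical, `model_size` is assumed to be in bytes, `#DEP_VAR` is assumed to be the number of dependent-variable values per image (e.g. the length of a one-hot label), and the factor of 4 is taken from the formula as bytes per value.

```python
def estimate_caching_memory_per_host(image_dimension, num_dep_vars,
                                     images_per_buffer, rows_per_seg,
                                     seg_per_host, model_size,
                                     models_per_seg, bytes_per_value=4):
    """Evaluate the commit's total_memory formula (bytes per host).

    Assumptions (not from the commit): a 'row' is one packed buffer holding
    images_per_buffer images, model_size is in bytes, and num_dep_vars is
    the number of dependent-variable values per image.
    """
    ind_sz = image_dimension * bytes_per_value * images_per_buffer  # IND_SZ
    dep_sz = num_dep_vars * bytes_per_value * images_per_buffer     # DEP_SZ
    memory_data = (seg_per_host * (rows_per_seg * ind_sz)
                   + seg_per_host * (rows_per_seg * dep_sz))
    memory_model = model_size * models_per_seg * seg_per_host
    return memory_data + memory_model


# 32x32x3 images, 10 classes, 256 images/buffer, 100 buffers/segment,
# 4 segments/host, one hypothetical 1 MB model per segment.
total = estimate_caching_memory_per_host(32 * 32 * 3, 10, 256, 100, 4,
                                         1_000_000, 1)
print(total, "bytes, about %.2f GiB" % (total / 2**30))
```

With these illustrative numbers the estimate is 1,266,387,200 bytes (about 1.18 GiB) per host, and almost all of it is cached image data rather than model weights.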
    
    Also:
    - use_caching param descr and examples added to user docs
    - Run each fit multiple dev-check test once for the non-cached and once for the cached case
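The cache lifecycle described in the commit message (fill SD on the very first hop, read from SD on later hops, clear it on the final training call) can be modelled in plain Python. This is a simplified sketch, not the committed fit_multiple_transition_caching: a dict stands in for PL/Python's SD, buffers are plain lists, and `train_fn` is a hypothetical callback standing in for Keras `model.fit()`.

```python
def cached_transition(SD, state, x_batch, y_batch, images_in_batch,
                      total_images, weights, is_final_training_call, train_fn):
    """Simplified per-segment caching transition (illustrative only).

    SD mimics PL/Python's per-function dict; train_fn(weights, x, y) is a
    hypothetical stand-in for fitting the segment model on one buffer.
    """
    agg_image_count = float(state) if state else 0.0

    if 'cache_set' in SD:
        # Later hops: data arrives as None because everything is cached.
        is_last_row = True
    else:
        # Very first hop: append each incoming buffer to the cache.
        if x_batch is None or y_batch is None:
            return state
        SD.setdefault('x_train', []).append(x_batch)
        SD.setdefault('y_train', []).append(y_batch)
        agg_image_count += images_in_batch
        is_last_row = agg_image_count == total_images
        if is_last_row:
            SD['cache_set'] = True

    # Train only once all of this segment's buffers are available.
    if is_last_row and weights is not None:
        for x, y in zip(SD['x_train'], SD['y_train']):
            weights = train_fn(weights, x, y)

    # Free the cached data after the final training call.
    if is_last_row and is_final_training_call:
        for key in ('x_train', 'y_train', 'cache_set'):
            SD.pop(key, None)

    # Return model state on the last row, otherwise the running image count.
    return weights if (is_last_row and weights is not None) else agg_image_count


def toy_train(w, x, y):
    # Toy "training" step: accumulate the buffer's values into the weights.
    return w + sum(x)


SD = {}
state = cached_transition(SD, None, [1, 2, 3], [0, 1, 0], 3, 5, 0, False, toy_train)
weights = cached_transition(SD, state, [4, 5], [1, 1], 2, 5, 0, False, toy_train)
final = cached_transition(SD, None, None, None, 0, None, weights, True, toy_train)
print(state, weights, final, SD)
```

In the usage lines, two calls make up the first hop for a segment holding two buffers (5 images). A later, final hop then passes `None` for the data and still trains on both cached buffers before emptying `SD`.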
---
 .../modules/deep_learning/madlib_keras.py_in       |  89 ++++++++-
 .../madlib_keras_fit_multiple_model.py_in          |  58 ++++--
 .../madlib_keras_fit_multiple_model.sql_in         |  50 +++--
 .../test/madlib_keras_model_selection.sql_in       | 132 ++++++++++----
 .../test/unit_tests/test_madlib_keras.py_in        | 202 ++++++++++++++++++++-
 5 files changed, 461 insertions(+), 70 deletions(-)
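The two flags that drive caching in run_training() in madlib_keras_fit_multiple_model.py_in can be tabulated per hop. The sketch below assumes, as the fit() loop suggests, that iterations are 1-based and mst_idx is 0-based. `hop_schedule` and its output keys are hypothetical names. It shows that, with caching, only one hop reads the real data columns, and that the NOT NULL filter on mst_key is dropped on the first and final hops so that every segment can fill or clear its cache.

```python
def hop_schedule(num_iterations, total_msts, use_caching):
    """Yield (iteration, mst_idx, flags) for each hop (illustrative only).

    Assumes 1-based iterations and 0-based mst_idx, matching the flag
    expressions used in fit() and run_training().
    """
    for it in range(1, num_iterations + 1):
        for mst_idx in range(total_msts):
            is_very_first_hop = mst_idx == 0 and it == 1
            is_final_training_call = (it == num_iterations
                                      and mst_idx == total_msts - 1)
            # With caching, only the very first hop passes the real data
            # columns; later hops pass NULLs and join the small dist-key table.
            reads_source_data = (not use_caching) or is_very_first_hop
            # The mst_key IS NOT NULL filter is removed on the first and final
            # hops so segments without an mst still populate or clear the cache.
            filters_null_mst = not (use_caching and
                                    (is_very_first_hop or is_final_training_call))
            yield it, mst_idx, dict(is_very_first_hop=is_very_first_hop,
                                    is_final_training_call=is_final_training_call,
                                    reads_source_data=reads_source_data,
                                    filters_null_mst=filters_null_mst)


for it, mst_idx, flags in hop_schedule(2, 3, use_caching=True):
    print(it, mst_idx, flags)
```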

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index e8eac71..0d55028 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -523,7 +523,7 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
-                                  agg_image_count, total_images)
+                                       agg_image_count, total_images)
     if is_last_row:
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
@@ -531,6 +531,93 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
 
     return return_state
 
+def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
+                             independent_var_shape, model_architecture,
+                             compile_params, fit_params, dist_key, dist_key_mapping,
+                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
+                             accessible_gpus_for_seg, prev_serialized_weights,
+                             is_final_training_call, custom_function_map=None, **kwargs):
+    """
+    This transition function is called when caching is enabled for
+    madlib_keras_fit_multiple_model().
+    The input params dependent_var and independent_var are passed in
+    as None, and dependent_var_shape and independent_var_shape as [0],
+    for all hops except the very first hop.
+    Some things to note in this function are:
+    - prev_serialized_weights can be passed in as None for the
+      very first hop and the final training call
+    - x_train, y_train and cache_set are cleared from SD for
+      is_final_training_call = TRUE
+    """
+    if not state:
+        agg_image_count = 0
+    else:
+        agg_image_count = float(state)
+
+    SD = kwargs['SD']
+    is_cache_set = 'cache_set' in SD
+
+    # Prepare the data
+    if is_cache_set:
+        if 'x_train' not in SD or 'y_train' not in SD:
+            plpy.error("cache not populated properly.")
+        total_images = None
+        is_last_row = True
+    else:
+        if not independent_var or not dependent_var:
+            return state
+        if 'x_train' not in SD:
+            SD['x_train'] = list()
+            SD['y_train'] = list()
+        agg_image_count += independent_var_shape[0]
+        total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
+                                                          images_per_seg)
+        is_last_row = agg_image_count == total_images
+        if is_last_row:
+            SD['cache_set'] = True
+        x_train_current = np_array_float32(independent_var, independent_var_shape)
+        y_train_current = np_array_int16(dependent_var, dependent_var_shape)
+        SD['x_train'].append(x_train_current)
+        SD['y_train'].append(y_train_current)
+
+    # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
+    # But if the weights are None, we do not want to set any model. So early return in that case
+    if prev_serialized_weights is None:
+        if is_final_training_call:
+            del SD['x_train']
+            del SD['y_train']
+            del SD['cache_set']
+        return float(agg_image_count)
+
+    segment_model = None
+    if is_last_row:
+        device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
+        segment_model, sess = get_init_model_and_sess(SD, device_name,
+                                                      accessible_gpus_for_seg[current_seg_id],
+                                                      segments_per_host,
+                                                      model_architecture, compile_params,
+                                                      custom_function_map)
+        set_model_weights(segment_model, prev_serialized_weights)
+
+        fit_params = parse_and_validate_fit_params(fit_params)
+        for i in range(len(SD['x_train'])):
+            # Fit segment model on data
+            segment_model.fit(SD['x_train'][i], SD['y_train'][i], **fit_params)
+
+
+    return_state = get_state_to_return(segment_model, is_last_row, True,
+                                       agg_image_count, total_images)
+
+    if is_last_row:
+        SD_STORE.clear_SD(SD)
+        clear_keras_session(sess)
+        if is_final_training_call:
+            del SD['x_train']
+            del SD['y_train']
+            del SD['cache_set']
+
+    return return_state
+
 def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image_count,
                         total_images):
     """
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index b847550..c821474 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -81,7 +81,7 @@ class FitMultipleModel():
                  model_selection_table, num_iterations,
                  use_gpus=False, validation_table=None,
                  metrics_compute_frequency=None, warm_start=False, name="",
-                 description="", **kwargs):
+                 description="", use_caching=False, **kwargs):
         # set the random seed for visit order/scheduling
         random.seed(1)
         if is_platform_pg():
@@ -97,6 +97,7 @@ class FitMultipleModel():
         self.metrics_compute_frequency = metrics_compute_frequency
         self.name = name
         self.description = description
+        self.use_caching = use_caching if use_caching is not None else False
         self.module_name = 'madlib_keras_fit_multiple_model'
         self.schema_madlib = schema_madlib
         self.version = madlib_version(self.schema_madlib)
@@ -115,6 +116,7 @@ class FitMultipleModel():
         self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
         self.use_gpus = use_gpus
         self.segments_per_host = get_segments_per_host()
+        self.cached_source_table = unique_string('cached_source_table')
         if self.use_gpus:
             self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
                 self.schema_madlib, self.segments_per_host, self.module_name)
@@ -233,7 +235,7 @@ class FitMultipleModel():
                 self.is_final_training_call = (iter == self.num_iterations and mst_idx == total_msts-1)
                 if mst_idx == 0:
                     start_iteration = time.time()
-                self.run_training(mst_idx)
+                self.run_training(mst_idx, mst_idx==0 and iter==1)
                 if mst_idx == (total_msts - 1):
                     end_iteration = time.time()
                     self.info_str = "\tTime for training in iteration " \
@@ -249,6 +251,7 @@ class FitMultipleModel():
                 if self.validation_table:
                     self.evaluate_model(iter, self.validation_table, False)
             plpy.info("\n"+self.info_str)
+        plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table};".format(self=self))
 
     def evaluate_model(self, epoch, table, is_train):
         if is_train:
@@ -594,7 +597,7 @@ class FitMultipleModel():
             if self.validation_table:
                 self.update_info_table(mst, False)
 
-    def run_training(self, mst_idx):
+    def run_training(self, mst_idx, is_very_first_hop):
         # NOTE: In the DL module, we want to avoid CREATING TEMP tables
         # (creates a slice which stays until the session is disconnected)
         # or minimize writing queries that generate plans with Motions (creating
@@ -622,12 +625,39 @@ class FitMultipleModel():
                    **locals())
         plpy.execute(mst_weights_query)
         use_gpus = self.use_gpus if self.use_gpus else False
+        dep_shape_col = self.dep_shape_col
+        ind_shape_col = self.ind_shape_col
+        dep_var = mb_dep_var_col
+        indep_var = mb_indep_var_col
+        source_table = self.source_table
+        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+        if self.use_caching:
+            # Caching populates the cache with independent_var and dependent_var on the very first hop.
+            # For the very first hop, we want to run the transition function on all segments, including
+            # the ones where the mst_key is NULL (for #mst < #seg), so we remove the NOT NULL check
+            # on mst_key. Once the cache is populated with the independent_var and dependent_var values,
+            # all subsequent hops pass independent_var and dependent_var as NULLs and join against a dummy
+            # src table (holding only the dist_key values) to reference the dist_key.
+            if is_very_first_hop:
+                plpy.execute("""
+                    DROP TABLE IF EXISTS {self.cached_source_table};
+                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
+                    """.format(self=self, dist_key_col=dist_key_col))
+            else:
+                dep_shape_col = 'ARRAY[0]'
+                ind_shape_col = 'ARRAY[0]'
+                dep_var = 'NULL'
+                indep_var = 'NULL'
+                source_table = self.cached_source_table
+            if is_very_first_hop or self.is_final_training_call:
+                where_clause = ""
+
         uda_query = """
             CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
             SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
                 {mb_indep_var_col},
-                {self.dep_shape_col},
-                {self.ind_shape_col},
+                {dep_shape_col},
+                {ind_shape_col},
                 {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
                 {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
                 {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
@@ -639,21 +669,27 @@ class FitMultipleModel():
                 {use_gpus}::BOOLEAN,
                 ARRAY{self.accessible_gpus_for_seg},
                 {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_iteration}::BOOLEAN,
+                {is_final_training_call}::BOOLEAN,
+                {use_caching}::BOOLEAN,
                 {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
                 )::BYTEA AS {self.model_weights_col},
                 {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
                 ,src.{dist_key_col} AS {dist_key_col}
-            FROM {self.source_table} src JOIN {self.mst_weights_tbl}
+            FROM {source_table} src JOIN {self.mst_weights_tbl}
                 USING ({dist_key_col})
-            WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL
+            {where_clause}
             GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
             DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=mb_dep_var_col,
-                       mb_indep_var_col=mb_indep_var_col,
-                       is_final_iteration=True,
+            """.format(mb_dep_var_col=dep_var,
+                       mb_indep_var_col=indep_var,
+                       dep_shape_col=dep_shape_col,
+                       ind_shape_col=ind_shape_col,
+                       is_final_training_call=self.is_final_training_call,
+                       use_caching=self.use_caching,
                        dist_key_col=dist_key_col,
                        use_gpus=use_gpus,
+                       source_table=source_table,
+                       where_clause=where_clause,
                        self=self
                        )
         plpy.execute(uda_query)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 392a3be..5b72672 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -88,14 +88,14 @@ You can set up the models and hyperparameters to try with the
 Model Selection</a> utility to define the unique combinations
 of model architectures, compile and fit parameters.
 
-@note If 'madlib_keras_fit_multiple_model()' is running on GPDB 5 and some versions
+@note 1. If 'madlib_keras_fit_multiple_model()' is running on GPDB 5 and some versions
 of GPDB 6, the database will
 keep adding to the disk space (in proportion to model size) and will only
 release the disk space once the fit multiple query has completed execution.
 This is not the case for GPDB 6.5.0+ where disk space is released during the
 fit multiple query.
 
-@note CUDA GPU memory cannot be released until the process holding it is terminated.
+@note 2. CUDA GPU memory cannot be released until the process holding it is terminated.
 When a MADlib deep learning function is called with GPUs, Greenplum internally
 creates a process (called a slice) which calls TensorFlow to do the computation.
 This process holds the GPU memory until one of the following two things happen:
@@ -121,7 +121,8 @@ madlib_keras_fit_multiple_model(
     metrics_compute_frequency,
     warm_start,
     name,
-    description
+    description,
+    use_caching
     )
 </pre>
 
@@ -231,6 +232,17 @@ madlib_keras_fit_multiple_model(
   <DD>TEXT, default: NULL.
     Free text string to provide a description, if desired.
   </DD>
+
+  <DT>use_caching (optional)</DT>
+  <DD>BOOLEAN, default: FALSE. Use caching of images in memory on the 
+  segment in order to speed up processing. 
+
+  @note
+  When set to TRUE, image byte arrays on each segment are maintained 
+  in cache (SD). This can speed up training significantly; however, the
+  memory usage per segment increases.  In effect, it 
+  requires enough available memory on a segment so that all images 
+  residing on that segment can be read into memory.
 </dl>
 
 <b>Output tables</b>
@@ -1155,7 +1167,7 @@ WHERE q.actual=q.estimated;
 and compute metrics every 3rd iteration using
 the 'metrics_compute_frequency' parameter. This can
 help reduce run time if you do not need metrics
-computed at every iteration.
+computed at every iteration.  Also turn on image caching.
 <pre class="example">
 DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;
 SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed',    -- source_table
@@ -1167,7 +1179,8 @@ SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed',    -- source_
                                                3,                     -- metrics compute frequency
                                                FALSE,                 -- warm start
                                               'Sophie L.',            -- name
-                                              'Model selection for iris dataset'  -- description
+                                              'Model selection for iris dataset',  -- description
+                                               TRUE                   -- use caching
                                              );
 </pre>
 View the model summary:
@@ -1282,7 +1295,8 @@ SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed',    -- source_
                                                1,                     -- metrics compute frequency
                                                TRUE,                  -- warm start
                                               'Sophie L.',            -- name
-                                              'Simple MLP for iris dataset'  -- description
+                                              'Simple MLP for iris dataset',  -- description
+                                               TRUE                   -- use caching
                                              );
 SELECT * FROM iris_multi_model_summary;
 </pre>
@@ -1380,10 +1394,9 @@ inference runtimes will be proportionally faster as the number of segments incre
 Supun Nakandala, Yuhao Zhang, and Arun Kumar, ACM SIGMOD 2019 DEEM Workshop,
 https://adalabucsd.github.io/papers/2019_Cerebro_DEEM.pdf
 
-[2] "Resource-Efficient and Reproducible Model Selection on Deep Learning Systems,"
-Supun Nakandala, Yuhao Zhang, and Arun Kumar, Technical Report, Computer Science and
-Engineering, University of California, San Diego
-https://adalabucsd.github.io/papers/TR_2019_Cerebro.pdf
+[2] "Cerebro: A Data System for Optimized Deep Learning Model Selection,"
+Supun Nakandala, Yuhao Zhang, and Arun Kumar, Proceedings of the VLDB Endowment (2020), Vol. 13, No. 11
+https://adalabucsd.github.io/papers/2020_Cerebro_VLDB.pdf
 
 [3] https://keras.io/
 
@@ -1416,7 +1429,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
     name                    VARCHAR,
-    description             VARCHAR
+    description             VARCHAR,
+    use_caching             BOOLEAN DEFAULT FALSE
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
     from utilities.control import SetGUC
@@ -1506,13 +1520,17 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
     segments_per_host          INTEGER,
     images_per_seg             INTEGER[],
     use_gpus                   BOOLEAN,
-    accessible_gpus_for_seg               INTEGER[],
+    accessible_gpus_for_seg    INTEGER[],
     prev_serialized_weights    BYTEA,
-    is_final_iteration         BOOLEAN,
+    is_final_training_call     BOOLEAN,
+    use_caching                BOOLEAN,
     custom_function_map        BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(is_multiple_model = True, **globals())
+    if use_caching:
+        return madlib_keras.fit_multiple_transition_caching(**globals())
+    else:
+        return madlib_keras.fit_transition(is_final_iteration = True, is_multiple_model = True, **globals())
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1533,6 +1551,7 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step_multiple_model(
     INTEGER[],
     BYTEA,
     BOOLEAN,
+    BOOLEAN,
     BYTEA);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
     /* dependent_var */              BYTEA,
@@ -1550,7 +1569,8 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
     /* use_gpus */                   BOOLEAN,
     /* accessible_gpus_for_seg */    INTEGER[],
     /* prev_serialized_weights */    BYTEA,
-    /* is_final_iteration */         BOOLEAN,
+    /* is_final_training_call */     BOOLEAN,
+    /* use_caching */                BOOLEAN,
     /* custom_function_obj_map */    BYTEA
 )(
     STYPE=BYTEA,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 82b2647..0c29246 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -344,16 +344,20 @@ SELECT load_model_selection_table(
 );
 
 -- Test for one-hot encoded input data
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_one_hot_encoded_packed',
-	'iris_multiple_model',
-	'mst_table_4row',
-	3,
-	FALSE
+CREATE OR REPLACE FUNCTION test_fit_multiple_one_hot_encoded_input(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+PERFORM madlib_keras_fit_multiple_model(
+        'iris_data_one_hot_encoded_packed'::VARCHAR,
+        'iris_multiple_model'::VARCHAR,
+        'mst_table_4row'::VARCHAR,
+        3,
+        FALSE, NULL, NULL, NULL, NULL, NULL,
+        caching
 );
 
-SELECT assert(
+PERFORM assert(
         model_arch_table = 'iris_model_arch' AND
         validation_table is NULL AND
         model_info = 'iris_multiple_model_info' AND
@@ -365,8 +369,7 @@ SELECT assert(
         independent_varname = 'attributes' AND
         madlib_version is NOT NULL AND
         num_iterations = 3 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
+        start_training_time < end_training_time AND
         dependent_vartype = 'integer[]' AND
         num_classes = NULL AND
         class_values = NULL AND
@@ -374,6 +377,15 @@ SELECT assert(
         metrics_iters = ARRAY[3],
         'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
+END;
+$$ language plpgsql VOLATILE;
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_one_hot_encoded_input(FALSE);
+
+-- Testing with caching
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_one_hot_encoded_input(TRUE);
 
 -- Test the output table created are all persistent(not unlogged)
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'Model output table is unlogged');
@@ -418,18 +430,23 @@ SELECT assert(
 FROM (SELECT * FROM mst_object_table_summary) summary;
 
 -- Test when number of configs(3) equals number of segments(3)
-DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT setseed(0);
-SELECT madlib_keras_fit_multiple_model(
+CREATE OR REPLACE FUNCTION test_fit_multiple_equal_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+
+PERFORM setseed(0);
+PERFORM madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table',
 	6,
 	FALSE,
-	'iris_data_one_hot_encoded_packed'
+	'iris_data_one_hot_encoded_packed', NULL, NULL, NULL, NULL,
+	caching
 );
 
-SELECT assert(
+PERFORM assert(
         source_table = 'iris_data_packed' AND
         validation_table = 'iris_data_one_hot_encoded_packed' AND
         model = 'iris_multiple_model' AND
@@ -438,8 +455,7 @@ SELECT assert(
         independent_varname = 'attributes' AND
         model_arch_table = 'iris_model_arch' AND
         num_iterations = 6 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
+        start_training_time < end_training_time AND
         madlib_version is NOT NULL AND
         num_classes = 3 AND
         class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
@@ -451,10 +467,10 @@ SELECT assert(
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-SELECT assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
+PERFORM assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -470,34 +486,47 @@ SELECT assert(
         array_upper(validation_loss, 1) = 1 AND
         array_upper(metrics_elapsed_time, 1) = 1,
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
+FROM (SELECT * FROM iris_multiple_model_info limit 1) info;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(
+PERFORM assert(
   training_loss[6]-training_loss[1] < 0.1 AND
   training_metrics[6]-training_metrics[1] > -0.1,
     'The loss and accuracy should have improved with more iterations.'
 )
 FROM iris_multiple_model_info
 WHERE compile_params like '%lr=0.001%';
+END;
+$$ LANGUAGE plpgsql;
+
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_equal_configs(FALSE);
+
+-- Testing with caching
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_equal_configs(TRUE);
 
 -- Test when number of configs(1) is less than number of segments(3)
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
+CREATE OR REPLACE FUNCTION test_fit_multiple_less_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+
+PERFORM madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_1row',
@@ -507,13 +536,14 @@ SELECT madlib_keras_fit_multiple_model(
 	1,
 	FALSE,
 	'multi_model_name',
-	'multi_model_descr'
+	'multi_model_descr',
+	caching
 );
 
-SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
+PERFORM assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -527,41 +557,55 @@ SELECT assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
+PERFORM assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
         'Keras Fit Multiple invalid elapsed time calculation.')
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(
+PERFORM assert(
         name = 'multi_model_name' AND
         description = 'multi_model_descr' AND
         metrics_compute_frequency = 1,
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+END;
+$$ LANGUAGE plpgsql;
 
--- Test when number of configs(4) larger than number of segments(3)
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
+SELECT test_fit_multiple_less_configs(FALSE);
+
+-- Test with caching when number of configs(1) is less than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_less_configs(TRUE);
+
+-- Test when number of configs(4) is larger than number of segments(3)
+CREATE OR REPLACE FUNCTION test_fit_multiple_more_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+
+PERFORM madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_4row',
 	3,
-	FALSE
+	FALSE, NULL, NULL, NULL, NULL, NULL,
+	caching
 );
 
 -- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
 -- but we change it to 'on' in fit_multiple.py. Assert that the value is
 -- reset after calling fit_multiple
-SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
+PERFORM CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
 
-SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
+PERFORM assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -574,11 +618,20 @@ SELECT assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
 AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+END;
+$$ LANGUAGE plpgsql;
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_more_configs(FALSE);
+
+-- Test with caching when number of configs(4) is larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_more_configs(TRUE);
 
 -- Test when class values have NULL values
 UPDATE iris_data_packed_summary SET class_values = ARRAY['Iris-setosa','Iris-versicolor',NULL];
@@ -606,7 +659,6 @@ CREATE TABLE __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed as select
 CREATE TABLE __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed_summary as select * from iris_data_packed_summary;
 
 -- do not drop the output table created in the previous test
-SELECT count(*) from iris_multiple_model;
 SELECT madlib_keras_fit_multiple_model(
 	'__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed',
 	'__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model',
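
The new `fit_multiple_transition_caching` unit tests in the next file pin down a small state machine around the per-segment `SD` dictionary. The first and middle buffers only append to `x_train`/`y_train` and return a running image count. The last buffer trains and sets `cache_set`. Later passes train straight from the cache. The final training call trains once more and drops the cache. The sketch below is inferred only from those assertions and is not the MADlib implementation. The function `cached_transition` and its `train_fn` callback are hypothetical. Keras session handling, GPU placement, and weight (de)serialization are omitted.

```python
def cached_transition(state, x_buffer, y_buffer, images_per_seg,
                      is_final_training_call, SD, train_fn):
    """Hypothetical sketch of a per-segment caching transition.

    state: running image count (None on the first buffer).
    SD: per-segment dictionary that survives across calls.
    train_fn: hypothetical callback that trains on all cached buffers
              and returns the updated weights.
    """
    if not SD.get('cache_set'):
        # Cache is still being filled: keep this buffer for later passes.
        SD.setdefault('x_train', []).append(x_buffer)
        SD.setdefault('y_train', []).append(y_buffer)
        image_count = (state or 0) + len(y_buffer)
        if image_count < images_per_seg:
            # Not the last buffer on this segment: only report progress.
            return image_count
    # Either the last buffer just arrived or the cache was already full,
    # so train on every cached buffer in a single call.
    weights = train_fn(SD['x_train'], SD['y_train'])
    if is_final_training_call:
        # Last model hop of the last iteration: release the cached data.
        for key in ('x_train', 'y_train', 'cache_set'):
            SD.pop(key, None)
    else:
        # Mark the cache as complete so later hops skip the data buffers.
        SD['cache_set'] = True
    return weights
```

With caching, only the first pass pays the cost of deserializing the data buffers. Every later model hop on that segment reuses the arrays already held in `SD`. That is why the cache-filled tests pass `None` as the state and still expect a full set of weights back.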
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 6dacdcd..4ccf2bd 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -145,7 +145,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
-    def test_fit_transition_multiple_model_first_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_first_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -172,6 +172,36 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
+    def test_fit_transition_multiple_model_cache_first_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+        starting_image_count = 0
+        ending_image_count = len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+
+        k = {'SD': {}}
+
+        new_state = self.subject.fit_multiple_transition_caching(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
+        image_count = new_state
+        self.assertEqual(ending_image_count, image_count)
+        # set_session should only be called for the last buffer
+        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # Clear session must not be called for the first buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue('cache_set' not in k['SD'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
     def _test_fit_transition_middle_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
@@ -228,7 +258,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # Clear session and sess.close must not get called for the middle buffer
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
-    def test_fit_transition_multiple_model_middle_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_middle_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -259,6 +289,41 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # Clear session and sess.close must not get called for the middle buffer
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
+    def test_fit_transition_multiple_model_cache_middle_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+
+        state = starting_image_count
+        new_state = self.subject.fit_multiple_transition_caching(
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+        image_count = new_state
+        self.assertEqual(ending_image_count, image_count)
+        # set_session is only called for the last buffer
+        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must not get called for the middle buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue('cache_set' not in k['SD'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
     def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
@@ -327,7 +392,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         #  but not in postgres
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
-    def test_fit_transition_multiple_model_last_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_last_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -362,6 +427,137 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         #  but not in postgres
         self.assertEqual(1, self.subject.K.clear_session.call_count)
 
+    def test_fit_transition_multiple_model_cache_last_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+
+        state = starting_image_count
+        new_state = self.subject.fit_multiple_transition_caching(
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        weights = np.rint(state[0:]).astype(np.int)
+
+        ## image count should not be added to the final state of
+        # fit multiple
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is only called for the last buffer
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must get called for the last buffer
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue(k['SD']['cache_set'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
+    def test_fit_transition_multiple_model_cache_filled_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+
+        new_state = self.subject.fit_multiple_transition_caching(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        weights = np.rint(state[0:]).astype(np.int)
+
+        ## image count should not be added to the final state of
+        # fit multiple
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is only called for the last buffer
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must get called for the last buffer
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue(k['SD']['cache_set'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
+    def test_fit_transition_multiple_model_cache_filled_final_training_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+
+        new_state = self.subject.fit_multiple_transition_caching(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        weights = np.rint(state[0:]).astype(np.int)
+
+        ## image count should not be added to the final state of
+        # fit multiple
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is only called for the last buffer
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must get called for the last buffer
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue('cache_set' not in k['SD'])
+        self.assertTrue('x_train' not in k['SD'])
+        self.assertTrue('y_train' not in k['SD'])
+
     def test_fit_transition_first_buffer_pass_pg(self):
         self._test_fit_transition_first_buffer_pass(True)