You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by do...@apache.org on 2020/09/29 18:08:01 UTC

[madlib] branch master updated (ba6c836 -> 3cb2305)

This is an automated email from the ASF dual-hosted git repository.

domino pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git.


    from ba6c836  DL: [AutoML] Set plan_cache_mode when calling fit multiple model
     new a119fe8  DL: Implement caching for fit_multiple_model
     new fe58ef5  Address review comments
     new b20119d  Convert EXECUTE to PERFORM
     new 3cb2305  add use_caching param descr and examples to user docs

The 4 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../modules/deep_learning/madlib_keras.py_in       |  89 ++++++++-
 .../madlib_keras_fit_multiple_model.py_in          |  58 ++++--
 .../madlib_keras_fit_multiple_model.sql_in         |  50 +++--
 .../test/madlib_keras_model_selection.sql_in       | 132 ++++++++++----
 .../test/unit_tests/test_madlib_keras.py_in        | 202 ++++++++++++++++++++-
 5 files changed, 461 insertions(+), 70 deletions(-)


[madlib] 03/04: Convert EXECUTE to PERFORM

Posted by do...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit b20119dd1bc21a221db843c0288853f257f6e610
Author: Domino Valdano <dv...@vmware.com>
AuthorDate: Thu Sep 24 20:00:39 2020 -0700

    Convert EXECUTE to PERFORM
    
    EXECUTE is meant to run a dynamically constructed command string,
    whereas PERFORM evaluates a query and discards its result.  The
    EXECUTE form probably only worked because the result of assert()
    was converted to an empty command string, so EXECUTE ran an empty
    command.  It may not matter in this case, but PERFORM is the more
    straightforward choice.
---
 .../test/madlib_keras_model_selection.sql_in       | 52 +++++++++++-----------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index d39f2a0..0c29246 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -348,16 +348,16 @@ CREATE OR REPLACE FUNCTION test_fit_multiple_one_hot_encoded_input(caching boole
 RETURNS VOID AS
 $$
 BEGIN
-EXECUTE madlib_keras_fit_multiple_model(
-	'iris_data_one_hot_encoded_packed'::VARCHAR,
-	'iris_multiple_model'::VARCHAR,
-	'mst_table_4row'::VARCHAR,
-	3,
-	FALSE, NULL, NULL, NULL, NULL, NULL,
-    caching
+PERFORM madlib_keras_fit_multiple_model(
+        'iris_data_one_hot_encoded_packed'::VARCHAR,
+        'iris_multiple_model'::VARCHAR,
+        'mst_table_4row'::VARCHAR,
+        3,
+        FALSE, NULL, NULL, NULL, NULL, NULL,
+        caching
 );
 
-EXECUTE assert(
+PERFORM assert(
         model_arch_table = 'iris_model_arch' AND
         validation_table is NULL AND
         model_info = 'iris_multiple_model_info' AND
@@ -378,7 +378,7 @@ EXECUTE assert(
         'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 END;
-$$ language plpgsql;
+$$ language plpgsql VOLATILE;
 
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
 SELECT test_fit_multiple_one_hot_encoded_input(FALSE);
@@ -435,8 +435,8 @@ RETURNS VOID AS
 $$
 BEGIN
 
-EXECUTE setseed(0);
-EXECUTE madlib_keras_fit_multiple_model(
+PERFORM setseed(0);
+PERFORM madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table',
@@ -446,7 +446,7 @@ EXECUTE madlib_keras_fit_multiple_model(
 	caching
 );
 
-EXECUTE assert(
+PERFORM assert(
         source_table = 'iris_data_packed' AND
         validation_table = 'iris_data_one_hot_encoded_packed' AND
         model = 'iris_multiple_model' AND
@@ -467,7 +467,7 @@ EXECUTE assert(
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-EXECUTE assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
+PERFORM assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
 PERFORM assert(
@@ -488,22 +488,22 @@ PERFORM assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info limit 1) info;
 
-EXECUTE assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
 
-EXECUTE assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
 
-EXECUTE assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
 
-EXECUTE assert(
+PERFORM assert(
   training_loss[6]-training_loss[1] < 0.1 AND
   training_metrics[6]-training_metrics[1] > -0.1,
     'The loss and accuracy should have improved with more iterations.'
@@ -526,7 +526,7 @@ RETURNS VOID AS
 $$
 BEGIN
 
-EXECUTE madlib_keras_fit_multiple_model(
+PERFORM madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_1row',
@@ -540,10 +540,10 @@ EXECUTE madlib_keras_fit_multiple_model(
 	caching
 );
 
-EXECUTE assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
+PERFORM assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-EXECUTE assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -557,18 +557,18 @@ EXECUTE assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-EXECUTE assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
+PERFORM assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
         'Keras Fit Multiple invalid elapsed time calculation.')
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-EXECUTE assert(
+PERFORM assert(
         name = 'multi_model_name' AND
         description = 'multi_model_descr' AND
         metrics_compute_frequency = 1,
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-EXECUTE assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
@@ -588,7 +588,7 @@ RETURNS VOID AS
 $$
 BEGIN
 
-EXECUTE madlib_keras_fit_multiple_model(
+PERFORM madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_4row',
@@ -602,7 +602,7 @@ EXECUTE madlib_keras_fit_multiple_model(
 -- reset after calling fit_multiple
 PERFORM CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
 
-EXECUTE assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
+PERFORM assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
 PERFORM assert(
@@ -618,7 +618,7 @@ PERFORM assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-EXECUTE assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text


[madlib] 02/04: Address review comments

Posted by do...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit fe58ef58302b54469002c18c793528ce8455356c
Author: Ekta Khanna <ek...@vmware.com>
AuthorDate: Tue Sep 22 18:14:05 2020 -0700

    Address review comments
---
 .../modules/deep_learning/madlib_keras.py_in       |   4 +-
 .../madlib_keras_fit_multiple_model.sql_in         |   2 +-
 .../test/madlib_keras_model_selection.sql_in       | 304 ++++++---------------
 .../test/unit_tests/test_madlib_keras.py_in        |  10 +-
 4 files changed, 87 insertions(+), 233 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 00889b9..0d55028 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -531,7 +531,7 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
 
     return return_state
 
-def fit_multiple_transition(state, dependent_var, independent_var, dependent_var_shape,
+def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
                              independent_var_shape, model_architecture,
                              compile_params, fit_params, dist_key, dist_key_mapping,
                              current_seg_id, segments_per_host, images_per_seg, use_gpus,
@@ -542,7 +542,7 @@ def fit_multiple_transition(state, dependent_var, independent_var, dependent_var
     madlib_keras_fit_multiple_model().
     The input params: dependent_var, independent_var are passed in
     as None and dependent_var_shape, independent_var_shape as [0]
-    for all hops except the very firt hop
+    for all hops except the very first hop
     Some things to note in this function are:
     - prev_serialized_weights can be passed in as None for the
       very first hop and the final training call
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 5a50733..1805eb7 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1515,7 +1515,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     if use_caching:
-        return madlib_keras.fit_multiple_transition(**globals())
+        return madlib_keras.fit_multiple_transition_caching(**globals())
     else:
         return madlib_keras.fit_transition(is_final_iteration = True, is_multiple_model = True, **globals())
 $$ LANGUAGE plpythonu
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 31bd9d6..d39f2a0 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -344,16 +344,20 @@ SELECT load_model_selection_table(
 );
 
 -- Test for one-hot encoded input data
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_one_hot_encoded_packed',
-	'iris_multiple_model',
-	'mst_table_4row',
+CREATE OR REPLACE FUNCTION test_fit_multiple_one_hot_encoded_input(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+EXECUTE madlib_keras_fit_multiple_model(
+	'iris_data_one_hot_encoded_packed'::VARCHAR,
+	'iris_multiple_model'::VARCHAR,
+	'mst_table_4row'::VARCHAR,
 	3,
-	FALSE
+	FALSE, NULL, NULL, NULL, NULL, NULL,
+    caching
 );
 
-SELECT assert(
+EXECUTE assert(
         model_arch_table = 'iris_model_arch' AND
         validation_table is NULL AND
         model_info = 'iris_multiple_model_info' AND
@@ -365,8 +369,7 @@ SELECT assert(
         independent_varname = 'attributes' AND
         madlib_version is NOT NULL AND
         num_iterations = 3 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
+        start_training_time < end_training_time AND
         dependent_vartype = 'integer[]' AND
         num_classes = NULL AND
         class_values = NULL AND
@@ -374,39 +377,15 @@ SELECT assert(
         metrics_iters = ARRAY[3],
         'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
+END;
+$$ language plpgsql;
 
--- Testing with caching
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_one_hot_encoded_packed',
-	'iris_multiple_model',
-	'mst_table_4row',
-	3,
-	FALSE, NULL, NULL, NULL, NULL, NULL,
-	TRUE
-);
+SELECT test_fit_multiple_one_hot_encoded_input(FALSE);
 
-SELECT assert(
-        model_arch_table = 'iris_model_arch' AND
-        validation_table is NULL AND
-        model_info = 'iris_multiple_model_info' AND
-        source_table = 'iris_data_one_hot_encoded_packed' AND
-        model = 'iris_multiple_model' AND
-        model_selection_table = 'mst_table_4row' AND
-        object_table IS NULL AND
-        dependent_varname = 'class_one_hot_encoded' AND
-        independent_varname = 'attributes' AND
-        madlib_version is NOT NULL AND
-        num_iterations = 3 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
-        dependent_vartype = 'integer[]' AND
-        num_classes = NULL AND
-        class_values = NULL AND
-        normalizing_const = 1 AND
-        metrics_iters = ARRAY[3],
-        'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
-FROM (SELECT * FROM iris_multiple_model_summary) summary;
+-- Testing with caching
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_one_hot_encoded_input(TRUE);
 
 -- Test the output table created are all persistent(not unlogged)
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'Model output table is unlogged');
@@ -451,97 +430,23 @@ SELECT assert(
 FROM (SELECT * FROM mst_object_table_summary) summary;
 
 -- Test when number of configs(3) equals number of segments(3)
-DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT setseed(0);
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_packed',
-	'iris_multiple_model',
-	'mst_table',
-	6,
-	FALSE,
-	'iris_data_one_hot_encoded_packed'
-);
-
-SELECT assert(
-        source_table = 'iris_data_packed' AND
-        validation_table = 'iris_data_one_hot_encoded_packed' AND
-        model = 'iris_multiple_model' AND
-        model_info = 'iris_multiple_model_info' AND
-        dependent_varname = 'class_text' AND
-        independent_varname = 'attributes' AND
-        model_arch_table = 'iris_model_arch' AND
-        num_iterations = 6 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
-        madlib_version is NOT NULL AND
-        num_classes = 3 AND
-        class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
-        dependent_vartype LIKE '%char%' AND
-        normalizing_const = 1 AND
-        name IS NULL AND
-        description IS NULL AND
-        metrics_compute_frequency = 6,
-        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
-FROM (SELECT * FROM iris_multiple_model_summary) summary;
-
-SELECT assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
-FROM iris_multiple_model_info;
-
-SELECT assert(
-        model_id = 1 AND
-        model_type = 'madlib_keras' AND
-        model_size > 0 AND
-        fit_params = $MAD$batch_size=50, epochs=1$MAD$::text AND
-        metrics_type = '{accuracy}' AND
-        training_metrics_final >= 0  AND
-        training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 1 AND
-        array_upper(training_loss, 1) = 1 AND
-        validation_metrics_final >= 0  AND
-        validation_loss_final  >= 0  AND
-        array_upper(validation_metrics, 1) = 1 AND
-        array_upper(validation_loss, 1) = 1 AND
-        array_upper(metrics_elapsed_time, 1) = 1,
-        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
-
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
-
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
-
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
-
-SELECT assert(
-  training_loss[6]-training_loss[1] < 0.1 AND
-  training_metrics[6]-training_metrics[1] > -0.1,
-    'The loss and accuracy should have improved with more iterations.'
-)
-FROM iris_multiple_model_info
-WHERE compile_params like '%lr=0.001%';
+CREATE OR REPLACE FUNCTION test_fit_multiple_equal_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
 
--- Testing with caching
-DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT setseed(0);
-SELECT madlib_keras_fit_multiple_model(
+EXECUTE setseed(0);
+EXECUTE madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table',
 	6,
 	FALSE,
 	'iris_data_one_hot_encoded_packed', NULL, NULL, NULL, NULL,
-	TRUE
+	caching
 );
 
-SELECT assert(
+EXECUTE assert(
         source_table = 'iris_data_packed' AND
         validation_table = 'iris_data_one_hot_encoded_packed' AND
         model = 'iris_multiple_model' AND
@@ -550,8 +455,7 @@ SELECT assert(
         independent_varname = 'attributes' AND
         model_arch_table = 'iris_model_arch' AND
         num_iterations = 6 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
+        start_training_time < end_training_time AND
         madlib_version is NOT NULL AND
         num_classes = 3 AND
         class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
@@ -563,10 +467,10 @@ SELECT assert(
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-SELECT assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
+EXECUTE assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -582,82 +486,47 @@ SELECT assert(
         array_upper(validation_loss, 1) = 1 AND
         array_upper(metrics_elapsed_time, 1) = 1,
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
+FROM (SELECT * FROM iris_multiple_model_info limit 1) info;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(
+EXECUTE assert(
   training_loss[6]-training_loss[1] < 0.1 AND
   training_metrics[6]-training_metrics[1] > -0.1,
     'The loss and accuracy should have improved with more iterations.'
 )
 FROM iris_multiple_model_info
 WHERE compile_params like '%lr=0.001%';
+END;
+$$ LANGUAGE plpgsql;
 
--- Test when number of configs(1) is less than number of segments(3)
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_packed',
-	'iris_multiple_model',
-	'mst_table_1row',
-	3,
-	FALSE,
-	NULL,
-	1,
-	FALSE,
-	'multi_model_name',
-	'multi_model_descr'
-);
-
-SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
-FROM iris_multiple_model_info;
-
-SELECT assert(
-        model_id = 1 AND
-        model_type = 'madlib_keras' AND
-        model_size > 0 AND
-        fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
-        metrics_type = '{accuracy}' AND
-        training_metrics_final >= 0  AND
-        training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 3 AND
-        array_upper(training_loss, 1) = 3 AND
-        array_upper(metrics_elapsed_time, 1) = 3,
-        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
-
-SELECT assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
-        'Keras Fit Multiple invalid elapsed time calculation.')
-FROM (SELECT * FROM iris_multiple_model_info) info;
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_equal_configs(FALSE);
 
-SELECT assert(
-        name = 'multi_model_name' AND
-        description = 'multi_model_descr' AND
-        metrics_compute_frequency = 1,
-        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
-FROM (SELECT * FROM iris_multiple_model_summary) summary;
+-- Testing with caching
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_equal_configs(TRUE);
 
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+-- Test when number of configs(1) is less than number of segments(3)
+CREATE OR REPLACE FUNCTION test_fit_multiple_less_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
 
--- Testing with caching configs(1) is less than number of segments(3)
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
+EXECUTE madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_1row',
@@ -668,13 +537,13 @@ SELECT madlib_keras_fit_multiple_model(
 	FALSE,
 	'multi_model_name',
 	'multi_model_descr',
-	TRUE
+	caching
 );
 
-SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
+EXECUTE assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+EXECUTE assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -688,79 +557,55 @@ SELECT assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
+EXECUTE assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
         'Keras Fit Multiple invalid elapsed time calculation.')
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(
+EXECUTE assert(
         name = 'multi_model_name' AND
         description = 'multi_model_descr' AND
         metrics_compute_frequency = 1,
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+END;
+$$ LANGUAGE plpgsql;
 
--- Test when number of configs(4) larger than number of segments(3)
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_packed',
-	'iris_multiple_model',
-	'mst_table_4row',
-	3,
-	FALSE
-);
-
--- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
--- but we change it to 'on' in fit_multiple.py. Assert that the value is
--- reset after calling fit_multiple
-SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
-
-SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
-FROM iris_multiple_model_info;
+SELECT test_fit_multiple_less_configs(FALSE);
 
-SELECT assert(
-        model_id = 1 AND
-        model_type = 'madlib_keras' AND
-        model_size > 0 AND
-        metrics_type = '{accuracy}' AND
-        training_metrics_final >= 0  AND
-        training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 1 AND
-        array_upper(training_loss, 1) = 1 AND
-        array_upper(metrics_elapsed_time, 1) = 1,
-        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
+-- Testing with caching configs(1) is less than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_less_configs(TRUE);
 
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
-AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+-- Test when number of configs(4) larger than number of segments(3)
+CREATE OR REPLACE FUNCTION test_fit_multiple_more_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
 
--- Test with caching when number of configs(4) larger than number of segments(3)
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
+EXECUTE madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_4row',
 	3,
 	FALSE, NULL, NULL, NULL, NULL, NULL,
-	TRUE
+	caching
 );
 
 -- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
 -- but we change it to 'on' in fit_multiple.py. Assert that the value is
 -- reset after calling fit_multiple
-SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
+PERFORM CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
 
-SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
+EXECUTE assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -773,11 +618,20 @@ SELECT assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
 AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+END;
+$$ LANGUAGE plpgsql;
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_more_configs(FALSE);
+
+-- Test with caching when number of configs(4) larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_more_configs(TRUE);
 
 -- Test when class values have NULL values
 UPDATE iris_data_packed_summary SET class_values = ARRAY['Iris-setosa','Iris-versicolor',NULL];
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 9990adc..4ccf2bd 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -184,7 +184,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         k = {'SD': {}}
 
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
@@ -307,7 +307,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         k = {'SD': {'x_train': x_train, 'y_train': y_train}}
 
         state = starting_image_count
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
@@ -447,7 +447,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         k = {'SD': {'x_train': x_train, 'y_train': y_train}}
 
         state = starting_image_count
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
@@ -491,7 +491,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
 
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
@@ -535,7 +535,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
 
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,


[madlib] 04/04: add use_caching param descr and examples to user docs

Posted by do...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 3cb2305fc8c3c68912c1d5fa397fabb834a8a3a8
Author: Frank McQuillan <fm...@pivotal.io>
AuthorDate: Fri Sep 25 17:35:33 2020 -0700

    add use_caching param descr and examples to user docs
---
 .../madlib_keras_fit_multiple_model.sql_in         | 33 +++++++++++++++-------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 1805eb7..5b72672 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -88,14 +88,14 @@ You can set up the models and hyperparameters to try with the
 Model Selection</a> utility to define the unique combinations
 of model architectures, compile and fit parameters.
 
-@note If 'madlib_keras_fit_multiple_model()' is running on GPDB 5 and some versions
+@note 1. If 'madlib_keras_fit_multiple_model()' is running on GPDB 5 and some versions
 of GPDB 6, the database will
 keep adding to the disk space (in proportion to model size) and will only
 release the disk space once the fit multiple query has completed execution.
 This is not the case for GPDB 6.5.0+ where disk space is released during the
 fit multiple query.
 
-@note CUDA GPU memory cannot be released until the process holding it is terminated.
+@note 2. CUDA GPU memory cannot be released until the process holding it is terminated.
 When a MADlib deep learning function is called with GPUs, Greenplum internally
 creates a process (called a slice) which calls TensorFlow to do the computation.
 This process holds the GPU memory until one of the following two things happen:
@@ -121,7 +121,8 @@ madlib_keras_fit_multiple_model(
     metrics_compute_frequency,
     warm_start,
     name,
-    description
+    description,
+    use_caching
     )
 </pre>
 
@@ -231,6 +232,17 @@ madlib_keras_fit_multiple_model(
   <DD>TEXT, default: NULL.
     Free text string to provide a description, if desired.
   </DD>
+
+  <DT>use_caching (optional)</DT>
+  <DD>BOOLEAN, default: FALSE. Use caching of images in memory on the 
+  segment in order to speed up processing. 
+
+  @note
+  When set to TRUE, image byte arrays on each segment are maintained 
+  in cache (SD). This can speed up training significantly, however the 
+  memory usage per segment increases.  In effect, it 
+  requires enough available memory on a segment so that all images 
+  residing on that segment can be read into memory.
 </dl>
 
 <b>Output tables</b>
@@ -1155,7 +1167,7 @@ WHERE q.actual=q.estimated;
 and compute metrics every 3rd iteration using
 the 'metrics_compute_frequency' parameter. This can
 help reduce run time if you do not need metrics
-computed at every iteration.
+computed at every iteration.  Also turn on image caching.
 <pre class="example">
 DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;
 SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed',    -- source_table
@@ -1167,7 +1179,8 @@ SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed',    -- source_
                                                3,                     -- metrics compute frequency
                                                FALSE,                 -- warm start
                                               'Sophie L.',            -- name
-                                              'Model selection for iris dataset'  -- description
+                                              'Model selection for iris dataset',  -- description
+                                               TRUE                   -- use caching
                                              );
 </pre>
 View the model summary:
@@ -1282,7 +1295,8 @@ SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed',    -- source_
                                                1,                     -- metrics compute frequency
                                                TRUE,                  -- warm start
                                               'Sophie L.',            -- name
-                                              'Simple MLP for iris dataset'  -- description
+                                              'Simple MLP for iris dataset',  -- description
+                                               TRUE                   -- use caching
                                              );
 SELECT * FROM iris_multi_model_summary;
 </pre>
@@ -1380,10 +1394,9 @@ inference runtimes will be proportionally faster as the number of segments incre
 Supun Nakandala, Yuhao Zhang, and Arun Kumar, ACM SIGMOD 2019 DEEM Workshop,
 https://adalabucsd.github.io/papers/2019_Cerebro_DEEM.pdf
 
-[2] "Resource-Efficient and Reproducible Model Selection on Deep Learning Systems,"
-Supun Nakandala, Yuhao Zhang, and Arun Kumar, Technical Report, Computer Science and
-Engineering, University of California, San Diego
-https://adalabucsd.github.io/papers/TR_2019_Cerebro.pdf
+[2] "Cerebro: A Data System for Optimized Deep Learning Model Selection,"
+Supun Nakandala, Yuhao Zhang, and Arun Kumar, Proceedings of the VLDB Endowment (2020), Vol. 13, No. 11
+https://adalabucsd.github.io/papers/2020_Cerebro_VLDB.pdf
 
 [3] https://keras.io/
 


[madlib] 01/04: DL: Implement caching for fit_multiple_model

Posted by do...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit a119fe882337f68e4105f5cd24179b4d87121e00
Author: Ekta Khanna <ek...@vmware.com>
AuthorDate: Thu Sep 10 12:31:28 2020 -0700

    DL: Implement caching for fit_multiple_model
    
    Currently, passing the independent and dependent vars to the
    transition function is what takes up most of the time.
    As part of this commit, add a new fit_multiple_transition function that
    reads all the rows (for each seg) into the cache (SD) on the very first
    hop; for each subsequent hop/iteration, the data is read from the
    cache instead of the table, and the cache is cleared out at the final
    training call. This reduces the time spent passing the data to the
    transition function.
    Since the data is cached in memory, the memory usage per segment
    increases significantly. To keep this opt-in, a new optional param
    `use_caching` is added to madlib_keras_fit_multiple_model(), which should
    only be set to TRUE if the memory on each segment satisfies the following
    calculation:
    
       IND_SZ (indep var size of each row) = ((image_dimension)*4)*(#of images per buffer)
       DEP_SZ (dep var size of each row) = (#DEP_VAR * 4)*(#of images per buffer)
       memory_data = (#seg_per_host) * (#rows_per_seg * IND_SZ) + (#seg_per_host) * (#rows_per_seg * DEP_SZ)
       memory_model = model_size * #models_per_seg * #seg_per_host
       total_memory = memory_data + memory_model
---
 .../modules/deep_learning/madlib_keras.py_in       |  89 ++++++++-
 .../madlib_keras_fit_multiple_model.py_in          |  58 ++++--
 .../madlib_keras_fit_multiple_model.sql_in         |  17 +-
 .../test/madlib_keras_model_selection.sql_in       | 200 +++++++++++++++++++-
 .../test/unit_tests/test_madlib_keras.py_in        | 202 ++++++++++++++++++++-
 5 files changed, 545 insertions(+), 21 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index e8eac71..00889b9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -523,7 +523,7 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
-                                  agg_image_count, total_images)
+                                       agg_image_count, total_images)
     if is_last_row:
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
@@ -531,6 +531,93 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
 
     return return_state
 
+def fit_multiple_transition(state, dependent_var, independent_var, dependent_var_shape,
+                             independent_var_shape, model_architecture,
+                             compile_params, fit_params, dist_key, dist_key_mapping,
+                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
+                             accessible_gpus_for_seg, prev_serialized_weights,
+                             is_final_training_call, custom_function_map=None, **kwargs):
+    """
+    This transition function is called when caching is called for
+    madlib_keras_fit_multiple_model().
+    The input params: dependent_var, independent_var are passed in
+    as None and dependent_var_shape, independent_var_shape as [0]
+    for all hops except the very firt hop
+    Some things to note in this function are:
+    - prev_serialized_weights can be passed in as None for the
+      very first hop and the final training call
+    - x_train, y_train and cache_set is cleared from SD for
+      final_training_call = TRUE
+    """
+    if not state:
+        agg_image_count = 0
+    else:
+        agg_image_count = float(state)
+
+    SD = kwargs['SD']
+    is_cache_set = 'cache_set' in SD
+
+    # Prepare the data
+    if is_cache_set:
+        if 'x_train' not in SD or 'y_train' not in SD:
+            plpy.error("cache not populated properly.")
+        total_images = None
+        is_last_row = True
+    else:
+        if not independent_var or not dependent_var:
+            return state
+        if 'x_train' not in SD:
+            SD['x_train'] = list()
+            SD['y_train'] = list()
+        agg_image_count += independent_var_shape[0]
+        total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
+                                                          images_per_seg)
+        is_last_row = agg_image_count == total_images
+        if is_last_row:
+            SD['cache_set'] = True
+        x_train_current = np_array_float32(independent_var, independent_var_shape)
+        y_train_current = np_array_int16(dependent_var, dependent_var_shape)
+        SD['x_train'].append(x_train_current)
+        SD['y_train'].append(y_train_current)
+
+    # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
+    # But if the weights are None, we do not want to set any model. So early return in that case
+    if prev_serialized_weights is None:
+        if is_final_training_call:
+            del SD['x_train']
+            del SD['y_train']
+            del SD['cache_set']
+        return float(agg_image_count)
+
+    segment_model = None
+    if is_last_row:
+        device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
+        segment_model, sess = get_init_model_and_sess(SD, device_name,
+                                                      accessible_gpus_for_seg[current_seg_id],
+                                                      segments_per_host,
+                                                      model_architecture, compile_params,
+                                                      custom_function_map)
+        set_model_weights(segment_model, prev_serialized_weights)
+
+        fit_params = parse_and_validate_fit_params(fit_params)
+        for i in range(len(SD['x_train'])):
+            # Fit segment model on data
+            segment_model.fit(SD['x_train'][i], SD['y_train'][i], **fit_params)
+
+
+    return_state = get_state_to_return(segment_model, is_last_row, True,
+                                       agg_image_count, total_images)
+
+    if is_last_row:
+        SD_STORE.clear_SD(SD)
+        clear_keras_session(sess)
+        if is_final_training_call:
+            del SD['x_train']
+            del SD['y_train']
+            del SD['cache_set']
+
+    return return_state
+
 def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image_count,
                         total_images):
     """
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index b847550..c821474 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -81,7 +81,7 @@ class FitMultipleModel():
                  model_selection_table, num_iterations,
                  use_gpus=False, validation_table=None,
                  metrics_compute_frequency=None, warm_start=False, name="",
-                 description="", **kwargs):
+                 description="", use_caching=False, **kwargs):
         # set the random seed for visit order/scheduling
         random.seed(1)
         if is_platform_pg():
@@ -97,6 +97,7 @@ class FitMultipleModel():
         self.metrics_compute_frequency = metrics_compute_frequency
         self.name = name
         self.description = description
+        self.use_caching = use_caching if use_caching is not None else False
         self.module_name = 'madlib_keras_fit_multiple_model'
         self.schema_madlib = schema_madlib
         self.version = madlib_version(self.schema_madlib)
@@ -115,6 +116,7 @@ class FitMultipleModel():
         self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
         self.use_gpus = use_gpus
         self.segments_per_host = get_segments_per_host()
+        self.cached_source_table = unique_string('cached_source_table')
         if self.use_gpus:
             self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
                 self.schema_madlib, self.segments_per_host, self.module_name)
@@ -233,7 +235,7 @@ class FitMultipleModel():
                 self.is_final_training_call = (iter == self.num_iterations and mst_idx == total_msts-1)
                 if mst_idx == 0:
                     start_iteration = time.time()
-                self.run_training(mst_idx)
+                self.run_training(mst_idx, mst_idx==0 and iter==1)
                 if mst_idx == (total_msts - 1):
                     end_iteration = time.time()
                     self.info_str = "\tTime for training in iteration " \
@@ -249,6 +251,7 @@ class FitMultipleModel():
                 if self.validation_table:
                     self.evaluate_model(iter, self.validation_table, False)
             plpy.info("\n"+self.info_str)
+        plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table};".format(self=self))
 
     def evaluate_model(self, epoch, table, is_train):
         if is_train:
@@ -594,7 +597,7 @@ class FitMultipleModel():
             if self.validation_table:
                 self.update_info_table(mst, False)
 
-    def run_training(self, mst_idx):
+    def run_training(self, mst_idx, is_very_first_hop):
         # NOTE: In the DL module, we want to avoid CREATING TEMP tables
         # (creates a slice which stays until the session is disconnected)
         # or minimize writing queries that generate plans with Motions (creating
@@ -622,12 +625,39 @@ class FitMultipleModel():
                    **locals())
         plpy.execute(mst_weights_query)
         use_gpus = self.use_gpus if self.use_gpus else False
+        dep_shape_col = self.dep_shape_col
+        ind_shape_col = self.ind_shape_col
+        dep_var = mb_dep_var_col
+        indep_var = mb_indep_var_col
+        source_table = self.source_table
+        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+        if self.use_caching:
+            # Caching populates the independent_var and dependent_var into the cache on the very first hop
+            # For the very_first_hop, we want to run the transition function on all segments, including
+            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
+            # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
+            # table to join for referencing the dist_key
+            if is_very_first_hop:
+                plpy.execute("""
+                    DROP TABLE IF EXISTS {self.cached_source_table};
+                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
+                    """.format(self=self, dist_key_col=dist_key_col))
+            else:
+                dep_shape_col = 'ARRAY[0]'
+                ind_shape_col = 'ARRAY[0]'
+                dep_var = 'NULL'
+                indep_var = 'NULL'
+                source_table = self.cached_source_table
+            if is_very_first_hop or self.is_final_training_call:
+                where_clause = ""
+
         uda_query = """
             CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
             SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
                 {mb_indep_var_col},
-                {self.dep_shape_col},
-                {self.ind_shape_col},
+                {dep_shape_col},
+                {ind_shape_col},
                 {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
                 {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
                 {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
@@ -639,21 +669,27 @@ class FitMultipleModel():
                 {use_gpus}::BOOLEAN,
                 ARRAY{self.accessible_gpus_for_seg},
                 {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_iteration}::BOOLEAN,
+                {is_final_training_call}::BOOLEAN,
+                {use_caching}::BOOLEAN,
                 {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
                 )::BYTEA AS {self.model_weights_col},
                 {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
                 ,src.{dist_key_col} AS {dist_key_col}
-            FROM {self.source_table} src JOIN {self.mst_weights_tbl}
+            FROM {source_table} src JOIN {self.mst_weights_tbl}
                 USING ({dist_key_col})
-            WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL
+            {where_clause}
             GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
             DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=mb_dep_var_col,
-                       mb_indep_var_col=mb_indep_var_col,
-                       is_final_iteration=True,
+            """.format(mb_dep_var_col=dep_var,
+                       mb_indep_var_col=indep_var,
+                       dep_shape_col=dep_shape_col,
+                       ind_shape_col=ind_shape_col,
+                       is_final_training_call=self.is_final_training_call,
+                       use_caching=self.use_caching,
                        dist_key_col=dist_key_col,
                        use_gpus=use_gpus,
+                       source_table=source_table,
+                       where_clause=where_clause,
                        self=self
                        )
         plpy.execute(uda_query)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 392a3be..5a50733 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1416,7 +1416,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
     name                    VARCHAR,
-    description             VARCHAR
+    description             VARCHAR,
+    use_caching             BOOLEAN DEFAULT FALSE
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
     from utilities.control import SetGUC
@@ -1506,13 +1507,17 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
     segments_per_host          INTEGER,
     images_per_seg             INTEGER[],
     use_gpus                   BOOLEAN,
-    accessible_gpus_for_seg               INTEGER[],
+    accessible_gpus_for_seg    INTEGER[],
     prev_serialized_weights    BYTEA,
-    is_final_iteration         BOOLEAN,
+    is_final_training_call     BOOLEAN,
+    use_caching                BOOLEAN,
     custom_function_map        BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(is_multiple_model = True, **globals())
+    if use_caching:
+        return madlib_keras.fit_multiple_transition(**globals())
+    else:
+        return madlib_keras.fit_transition(is_final_iteration = True, is_multiple_model = True, **globals())
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1533,6 +1538,7 @@ DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step_multiple_model(
     INTEGER[],
     BYTEA,
     BOOLEAN,
+    BOOLEAN,
     BYTEA);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
     /* dependent_var */              BYTEA,
@@ -1550,7 +1556,8 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
     /* use_gpus */                   BOOLEAN,
     /* accessible_gpus_for_seg */    INTEGER[],
     /* prev_serialized_weights */    BYTEA,
-    /* is_final_iteration */         BOOLEAN,
+    /* is_final_training_call */     BOOLEAN,
+    /* use_caching */                BOOLEAN,
     /* custom_function_obj_map */    BYTEA
 )(
     STYPE=BYTEA,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 82b2647..31bd9d6 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -375,6 +375,39 @@ SELECT assert(
         'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
+-- Testing with caching
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_one_hot_encoded_packed',
+	'iris_multiple_model',
+	'mst_table_4row',
+	3,
+	FALSE, NULL, NULL, NULL, NULL, NULL,
+	TRUE
+);
+
+SELECT assert(
+        model_arch_table = 'iris_model_arch' AND
+        validation_table is NULL AND
+        model_info = 'iris_multiple_model_info' AND
+        source_table = 'iris_data_one_hot_encoded_packed' AND
+        model = 'iris_multiple_model' AND
+        model_selection_table = 'mst_table_4row' AND
+        object_table IS NULL AND
+        dependent_varname = 'class_one_hot_encoded' AND
+        independent_varname = 'attributes' AND
+        madlib_version is NOT NULL AND
+        num_iterations = 3 AND
+        start_training_time < now() AND
+        end_training_time < now() AND
+        dependent_vartype = 'integer[]' AND
+        num_classes = NULL AND
+        class_values = NULL AND
+        normalizing_const = 1 AND
+        metrics_iters = ARRAY[3],
+        'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
 -- Test the output table created are all persistent(not unlogged)
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'Model output table is unlogged');
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = false, 'Model summary output table is unlogged');
@@ -495,6 +528,85 @@ SELECT assert(
 FROM iris_multiple_model_info
 WHERE compile_params like '%lr=0.001%';
 
+-- Testing with caching
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table',
+	6,
+	FALSE,
+	'iris_data_one_hot_encoded_packed', NULL, NULL, NULL, NULL,
+	TRUE
+);
+
+SELECT assert(
+        source_table = 'iris_data_packed' AND
+        validation_table = 'iris_data_one_hot_encoded_packed' AND
+        model = 'iris_multiple_model' AND
+        model_info = 'iris_multiple_model_info' AND
+        dependent_varname = 'class_text' AND
+        independent_varname = 'attributes' AND
+        model_arch_table = 'iris_model_arch' AND
+        num_iterations = 6 AND
+        start_training_time < now() AND
+        end_training_time < now() AND
+        madlib_version is NOT NULL AND
+        num_classes = 3 AND
+        class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+        dependent_vartype LIKE '%char%' AND
+        normalizing_const = 1 AND
+        name IS NULL AND
+        description IS NULL AND
+        metrics_compute_frequency = 6,
+        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
+SELECT assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        fit_params = $MAD$batch_size=50, epochs=1$MAD$::text AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 1 AND
+        array_upper(training_loss, 1) = 1 AND
+        validation_metrics_final >= 0  AND
+        validation_loss_final  >= 0  AND
+        array_upper(validation_metrics, 1) = 1 AND
+        array_upper(validation_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
+
+SELECT assert(
+  training_loss[6]-training_loss[1] < 0.1 AND
+  training_metrics[6]-training_metrics[1] > -0.1,
+    'The loss and accuracy should have improved with more iterations.'
+)
+FROM iris_multiple_model_info
+WHERE compile_params like '%lr=0.001%';
+
 -- Test when number of configs(1) is less than number of segments(3)
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
 SELECT madlib_keras_fit_multiple_model(
@@ -543,6 +655,55 @@ SELECT assert(cnt = 1,
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
 
+-- Testing with caching configs(1) is less than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table_1row',
+	3,
+	FALSE,
+	NULL,
+	1,
+	FALSE,
+	'multi_model_name',
+	'multi_model_descr',
+	TRUE
+);
+
+SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 3 AND
+        array_upper(training_loss, 1) = 3 AND
+        array_upper(metrics_elapsed_time, 1) = 3,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
+        'Keras Fit Multiple invalid elapsed time calculation.')
+FROM (SELECT * FROM iris_multiple_model_info) info;
+
+SELECT assert(
+        name = 'multi_model_name' AND
+        description = 'multi_model_descr' AND
+        metrics_compute_frequency = 1,
+        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+
 -- Test when number of configs(4) larger than number of segments(3)
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
 SELECT madlib_keras_fit_multiple_model(
@@ -580,6 +741,44 @@ FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
 AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
 
+-- Test with caching when number of configs(4) larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model',
+	'mst_table_4row',
+	3,
+	FALSE, NULL, NULL, NULL, NULL, NULL,  -- intermediate optional arguments (FALSE / NULL)
+	TRUE  -- NOTE(review): presumably the use_caching flag this test exercises; confirm
+);
+
+-- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
+-- but we change it to 'on' in fit_multiple.py. Assert that the value is
+-- reset after calling fit_multiple.
+SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
+-- One info row per mst in mst_table_4row.
+SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
+FROM iris_multiple_model_info;
+-- Every info row must hold exactly one metrics/loss/elapsed-time entry.
+SELECT assert(
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  >= 0  AND
+        array_upper(training_metrics, 1) = 1 AND
+        array_upper(training_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
+-- Exactly one mst uses Adam(lr=0.01) together with batch_size=32.
+SELECT assert(cnt = 1,
+	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
+FROM (SELECT count(*) cnt FROM iris_multiple_model_info
+WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
+AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+
 -- Test when class values have NULL values
 UPDATE iris_data_packed_summary SET class_values = ARRAY['Iris-setosa','Iris-versicolor',NULL];
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
@@ -606,7 +805,6 @@ CREATE TABLE __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed as select
 CREATE TABLE __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed_summary as select * from iris_data_packed_summary;
 
 -- do not drop the output table created in the previous test
-SELECT count(*) from iris_multiple_model;
 SELECT madlib_keras_fit_multiple_model(
 	'__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed',
 	'__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model',
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 6dacdcd..9990adc 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -145,7 +145,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
-    def test_fit_transition_multiple_model_first_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_first_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -172,6 +172,36 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
+    def test_fit_transition_multiple_model_cache_first_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+        # First buffer on the segment: no incoming state and an empty SD.
+        ending_image_count = len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+
+        k = {'SD': {}}
+
+        new_state = self.subject.fit_multiple_transition(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tobytes(), True, **k)
+
+        image_count = new_state
+        self.assertEqual(ending_image_count, image_count)
+        # set_session is only called for the last buffer
+        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # Clear session must not be called for the first buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertNotIn('segment_model', k['SD'])
+        self.assertNotIn('cache_set', k['SD'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
     def _test_fit_transition_middle_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
@@ -228,7 +258,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # Clear session and sess.close must not get called for the middle buffer
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
-    def test_fit_transition_multiple_model_middle_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_middle_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -259,6 +289,41 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # Clear session and sess.close must not get called for the middle buffer
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
+    def test_fit_transition_multiple_model_cache_middle_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+        # Middle buffer: one buffer is already cached in SD from an earlier call.
+        starting_image_count = len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+
+        state = starting_image_count
+        new_state = self.subject.fit_multiple_transition(
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tobytes(), True, **k)
+        image_count = new_state
+        self.assertEqual(ending_image_count, image_count)
+        # set_session is only called for the last buffer
+        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must not get called for the middle buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertNotIn('segment_model', k['SD'])
+        self.assertNotIn('cache_set', k['SD'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
     def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
@@ -327,7 +392,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         #  but not in postgres
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
-    def test_fit_transition_multiple_model_last_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_last_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -362,6 +427,137 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         #  but not in postgres
         self.assertEqual(1, self.subject.K.clear_session.call_count)
 
+    def test_fit_transition_multiple_model_cache_last_buffer_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        # Two buffers are already cached in SD; this call passes the last one.
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+
+        state = starting_image_count
+        new_state = self.subject.fit_multiple_transition(
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tobytes(), False, **k)
+
+        state = np.frombuffer(new_state, dtype=np.float32)
+        weights = np.rint(state).astype(int)
+
+        # The image count should not be appended to the final state of
+        # fit multiple: the state holds only the model weights.
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is only called for the last buffer
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must get called for the last buffer
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertNotIn('segment_model', k['SD'])
+        self.assertTrue(k['SD']['cache_set'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
+    def test_fit_transition_multiple_model_cache_filled_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        # The cache is already filled (cache_set=True) with three buffers, so
+        # the call receives no incoming state (None).
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+
+        new_state = self.subject.fit_multiple_transition(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tobytes(), False, **k)
+
+        state = np.frombuffer(new_state, dtype=np.float32)
+        weights = np.rint(state).astype(int)
+
+        # The image count should not be appended to the final state of
+        # fit multiple: the state holds only the model weights.
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is called exactly once for the pass over the filled cache
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must also get called exactly once
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertNotIn('segment_model', k['SD'])
+        self.assertTrue(k['SD']['cache_set'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
+    def test_fit_transition_multiple_model_cache_filled_final_training_pass(self):
+        #TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        # Filled cache on the final training call: afterwards the cache
+        # entries are expected to be removed from SD (asserted below).
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+
+        new_state = self.subject.fit_multiple_transition(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tobytes(), True, **k)
+
+        state = np.frombuffer(new_state, dtype=np.float32)
+        weights = np.rint(state).astype(int)
+
+        # The image count should not be appended to the final state of
+        # fit multiple: the state holds only the model weights.
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is called exactly once for the pass over the filled cache
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # Clear session and sess.close must also get called exactly once
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertNotIn('segment_model', k['SD'])
+        self.assertNotIn('cache_set', k['SD'])
+        self.assertNotIn('x_train', k['SD'])
+        self.assertNotIn('y_train', k['SD'])
+
     def test_fit_transition_first_buffer_pass_pg(self):
         self._test_fit_transition_first_buffer_pass(True)