You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by do...@apache.org on 2020/09/29 18:08:03 UTC

[madlib] 02/04: Address review comments

This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit fe58ef58302b54469002c18c793528ce8455356c
Author: Ekta Khanna <ek...@vmware.com>
AuthorDate: Tue Sep 22 18:14:05 2020 -0700

    Address review comments
---
 .../modules/deep_learning/madlib_keras.py_in       |   4 +-
 .../madlib_keras_fit_multiple_model.sql_in         |   2 +-
 .../test/madlib_keras_model_selection.sql_in       | 304 ++++++---------------
 .../test/unit_tests/test_madlib_keras.py_in        |  10 +-
 4 files changed, 87 insertions(+), 233 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 00889b9..0d55028 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -531,7 +531,7 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
 
     return return_state
 
-def fit_multiple_transition(state, dependent_var, independent_var, dependent_var_shape,
+def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
                              independent_var_shape, model_architecture,
                              compile_params, fit_params, dist_key, dist_key_mapping,
                              current_seg_id, segments_per_host, images_per_seg, use_gpus,
@@ -542,7 +542,7 @@ def fit_multiple_transition(state, dependent_var, independent_var, dependent_var
     madlib_keras_fit_multiple_model().
     The input params: dependent_var, independent_var are passed in
     as None and dependent_var_shape, independent_var_shape as [0]
-    for all hops except the very firt hop
+    for all hops except the very first hop
     Some things to note in this function are:
     - prev_serialized_weights can be passed in as None for the
       very first hop and the final training call
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 5a50733..1805eb7 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1515,7 +1515,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     if use_caching:
-        return madlib_keras.fit_multiple_transition(**globals())
+        return madlib_keras.fit_multiple_transition_caching(**globals())
     else:
         return madlib_keras.fit_transition(is_final_iteration = True, is_multiple_model = True, **globals())
 $$ LANGUAGE plpythonu
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 31bd9d6..d39f2a0 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -344,16 +344,20 @@ SELECT load_model_selection_table(
 );
 
 -- Test for one-hot encoded input data
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_one_hot_encoded_packed',
-	'iris_multiple_model',
-	'mst_table_4row',
+CREATE OR REPLACE FUNCTION test_fit_multiple_one_hot_encoded_input(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+EXECUTE madlib_keras_fit_multiple_model(
+	'iris_data_one_hot_encoded_packed'::VARCHAR,
+	'iris_multiple_model'::VARCHAR,
+	'mst_table_4row'::VARCHAR,
 	3,
-	FALSE
+	FALSE, NULL, NULL, NULL, NULL, NULL,
+    caching
 );
 
-SELECT assert(
+EXECUTE assert(
         model_arch_table = 'iris_model_arch' AND
         validation_table is NULL AND
         model_info = 'iris_multiple_model_info' AND
@@ -365,8 +369,7 @@ SELECT assert(
         independent_varname = 'attributes' AND
         madlib_version is NOT NULL AND
         num_iterations = 3 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
+        start_training_time < end_training_time AND
         dependent_vartype = 'integer[]' AND
         num_classes = NULL AND
         class_values = NULL AND
@@ -374,39 +377,15 @@ SELECT assert(
         metrics_iters = ARRAY[3],
         'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
+END;
+$$ language plpgsql;
 
--- Testing with caching
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_one_hot_encoded_packed',
-	'iris_multiple_model',
-	'mst_table_4row',
-	3,
-	FALSE, NULL, NULL, NULL, NULL, NULL,
-	TRUE
-);
+SELECT test_fit_multiple_one_hot_encoded_input(FALSE);
 
-SELECT assert(
-        model_arch_table = 'iris_model_arch' AND
-        validation_table is NULL AND
-        model_info = 'iris_multiple_model_info' AND
-        source_table = 'iris_data_one_hot_encoded_packed' AND
-        model = 'iris_multiple_model' AND
-        model_selection_table = 'mst_table_4row' AND
-        object_table IS NULL AND
-        dependent_varname = 'class_one_hot_encoded' AND
-        independent_varname = 'attributes' AND
-        madlib_version is NOT NULL AND
-        num_iterations = 3 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
-        dependent_vartype = 'integer[]' AND
-        num_classes = NULL AND
-        class_values = NULL AND
-        normalizing_const = 1 AND
-        metrics_iters = ARRAY[3],
-        'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
-FROM (SELECT * FROM iris_multiple_model_summary) summary;
+-- Testing with caching
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_one_hot_encoded_input(TRUE);
 
 -- Test the output table created are all persistent(not unlogged)
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'Model output table is unlogged');
@@ -451,97 +430,23 @@ SELECT assert(
 FROM (SELECT * FROM mst_object_table_summary) summary;
 
 -- Test when number of configs(3) equals number of segments(3)
-DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT setseed(0);
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_packed',
-	'iris_multiple_model',
-	'mst_table',
-	6,
-	FALSE,
-	'iris_data_one_hot_encoded_packed'
-);
-
-SELECT assert(
-        source_table = 'iris_data_packed' AND
-        validation_table = 'iris_data_one_hot_encoded_packed' AND
-        model = 'iris_multiple_model' AND
-        model_info = 'iris_multiple_model_info' AND
-        dependent_varname = 'class_text' AND
-        independent_varname = 'attributes' AND
-        model_arch_table = 'iris_model_arch' AND
-        num_iterations = 6 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
-        madlib_version is NOT NULL AND
-        num_classes = 3 AND
-        class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
-        dependent_vartype LIKE '%char%' AND
-        normalizing_const = 1 AND
-        name IS NULL AND
-        description IS NULL AND
-        metrics_compute_frequency = 6,
-        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
-FROM (SELECT * FROM iris_multiple_model_summary) summary;
-
-SELECT assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
-FROM iris_multiple_model_info;
-
-SELECT assert(
-        model_id = 1 AND
-        model_type = 'madlib_keras' AND
-        model_size > 0 AND
-        fit_params = $MAD$batch_size=50, epochs=1$MAD$::text AND
-        metrics_type = '{accuracy}' AND
-        training_metrics_final >= 0  AND
-        training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 1 AND
-        array_upper(training_loss, 1) = 1 AND
-        validation_metrics_final >= 0  AND
-        validation_loss_final  >= 0  AND
-        array_upper(validation_metrics, 1) = 1 AND
-        array_upper(validation_loss, 1) = 1 AND
-        array_upper(metrics_elapsed_time, 1) = 1,
-        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
-
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
-
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
-
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
-
-SELECT assert(
-  training_loss[6]-training_loss[1] < 0.1 AND
-  training_metrics[6]-training_metrics[1] > -0.1,
-    'The loss and accuracy should have improved with more iterations.'
-)
-FROM iris_multiple_model_info
-WHERE compile_params like '%lr=0.001%';
+CREATE OR REPLACE FUNCTION test_fit_multiple_equal_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
 
--- Testing with caching
-DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT setseed(0);
-SELECT madlib_keras_fit_multiple_model(
+EXECUTE setseed(0);
+EXECUTE madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table',
 	6,
 	FALSE,
 	'iris_data_one_hot_encoded_packed', NULL, NULL, NULL, NULL,
-	TRUE
+	caching
 );
 
-SELECT assert(
+EXECUTE assert(
         source_table = 'iris_data_packed' AND
         validation_table = 'iris_data_one_hot_encoded_packed' AND
         model = 'iris_multiple_model' AND
@@ -550,8 +455,7 @@ SELECT assert(
         independent_varname = 'attributes' AND
         model_arch_table = 'iris_model_arch' AND
         num_iterations = 6 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
+        start_training_time < end_training_time AND
         madlib_version is NOT NULL AND
         num_classes = 3 AND
         class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
@@ -563,10 +467,10 @@ SELECT assert(
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-SELECT assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
+EXECUTE assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -582,82 +486,47 @@ SELECT assert(
         array_upper(validation_loss, 1) = 1 AND
         array_upper(metrics_elapsed_time, 1) = 1,
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
+FROM (SELECT * FROM iris_multiple_model_info limit 1) info;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(
+EXECUTE assert(
   training_loss[6]-training_loss[1] < 0.1 AND
   training_metrics[6]-training_metrics[1] > -0.1,
     'The loss and accuracy should have improved with more iterations.'
 )
 FROM iris_multiple_model_info
 WHERE compile_params like '%lr=0.001%';
+END;
+$$ LANGUAGE plpgsql;
 
--- Test when number of configs(1) is less than number of segments(3)
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_packed',
-	'iris_multiple_model',
-	'mst_table_1row',
-	3,
-	FALSE,
-	NULL,
-	1,
-	FALSE,
-	'multi_model_name',
-	'multi_model_descr'
-);
-
-SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
-FROM iris_multiple_model_info;
-
-SELECT assert(
-        model_id = 1 AND
-        model_type = 'madlib_keras' AND
-        model_size > 0 AND
-        fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
-        metrics_type = '{accuracy}' AND
-        training_metrics_final >= 0  AND
-        training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 3 AND
-        array_upper(training_loss, 1) = 3 AND
-        array_upper(metrics_elapsed_time, 1) = 3,
-        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
-
-SELECT assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
-        'Keras Fit Multiple invalid elapsed time calculation.')
-FROM (SELECT * FROM iris_multiple_model_info) info;
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_equal_configs(FALSE);
 
-SELECT assert(
-        name = 'multi_model_name' AND
-        description = 'multi_model_descr' AND
-        metrics_compute_frequency = 1,
-        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
-FROM (SELECT * FROM iris_multiple_model_summary) summary;
+-- Testing with caching
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_equal_configs(TRUE);
 
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+-- Test when number of configs(1) is less than number of segments(3)
+CREATE OR REPLACE FUNCTION test_fit_multiple_less_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
 
--- Testing with caching configs(1) is less than number of segments(3)
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
+EXECUTE madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_1row',
@@ -668,13 +537,13 @@ SELECT madlib_keras_fit_multiple_model(
 	FALSE,
 	'multi_model_name',
 	'multi_model_descr',
-	TRUE
+	caching
 );
 
-SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
+EXECUTE assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+EXECUTE assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -688,79 +557,55 @@ SELECT assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
+EXECUTE assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
         'Keras Fit Multiple invalid elapsed time calculation.')
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(
+EXECUTE assert(
         name = 'multi_model_name' AND
         description = 'multi_model_descr' AND
         metrics_compute_frequency = 1,
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+END;
+$$ LANGUAGE plpgsql;
 
--- Test when number of configs(4) larger than number of segments(3)
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_packed',
-	'iris_multiple_model',
-	'mst_table_4row',
-	3,
-	FALSE
-);
-
--- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
--- but we change it to 'on' in fit_multiple.py. Assert that the value is
--- reset after calling fit_multiple
-SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
-
-SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
-FROM iris_multiple_model_info;
+SELECT test_fit_multiple_less_configs(FALSE);
 
-SELECT assert(
-        model_id = 1 AND
-        model_type = 'madlib_keras' AND
-        model_size > 0 AND
-        metrics_type = '{accuracy}' AND
-        training_metrics_final >= 0  AND
-        training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 1 AND
-        array_upper(training_loss, 1) = 1 AND
-        array_upper(metrics_elapsed_time, 1) = 1,
-        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
+-- Testing with caching when number of configs(1) is less than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_less_configs(TRUE);
 
-SELECT assert(cnt = 1,
-	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
-FROM (SELECT count(*) cnt FROM iris_multiple_model_info
-WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
-AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+-- Test when number of configs(4) larger than number of segments(3)
+CREATE OR REPLACE FUNCTION test_fit_multiple_more_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
 
--- Test with caching when number of configs(4) larger than number of segments(3)
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
+EXECUTE madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_4row',
 	3,
 	FALSE, NULL, NULL, NULL, NULL, NULL,
-	TRUE
+	caching
 );
 
 -- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
 -- but we change it to 'on' in fit_multiple.py. Assert that the value is
 -- reset after calling fit_multiple
-SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
+PERFORM CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
 
-SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
+EXECUTE assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -773,11 +618,20 @@ SELECT assert(
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(cnt = 1,
+EXECUTE assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
 AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+END;
+$$ LANGUAGE plpgsql;
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_more_configs(FALSE);
+
+-- Test with caching when number of configs(4) larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_more_configs(TRUE);
 
 -- Test when class values have NULL values
 UPDATE iris_data_packed_summary SET class_values = ARRAY['Iris-setosa','Iris-versicolor',NULL];
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 9990adc..4ccf2bd 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -184,7 +184,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         k = {'SD': {}}
 
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
@@ -307,7 +307,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         k = {'SD': {'x_train': x_train, 'y_train': y_train}}
 
         state = starting_image_count
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
@@ -447,7 +447,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         k = {'SD': {'x_train': x_train, 'y_train': y_train}}
 
         state = starting_image_count
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
@@ -491,7 +491,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
 
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
@@ -535,7 +535,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
 
-        new_state = self.subject.fit_multiple_transition(
+        new_state = self.subject.fit_multiple_transition_caching(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,