You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2020/11/25 03:08:37 UTC

[GitHub] [madlib] reductionista commented on a change in pull request #522: DL: Remove keras dependency

reductionista commented on a change in pull request #522:
URL: https://github.com/apache/madlib/pull/522#discussion_r529100647



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -529,9 +547,12 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
                                        agg_image_count, total_images)
-    if is_last_row:
-        if is_final_iteration or is_multiple_model:
-            SD_STORE.clear_SD(SD)
+    # if is_last_row:
+    #     if is_final_iteration or is_multiple_model:
+    #         GD_STORE.clear_GD(GD)

Review comment:
       Remove commented lines?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -899,11 +919,14 @@ def internal_keras_eval_transition(state, dependent_var, independent_var,
                                                       images_per_seg)
 
     if agg_image_count == total_images and is_final_iteration:
+        plpy.info('clearing session and gd')
         K.clear_session()
         sess.close()
-        SD_STORE.clear_SD(SD)
-        del segment_model
+        GD_STORE.clear_GD(GD)

Review comment:
       Since we're changing this anyway, can we rename this method from GD_STORE.clear_GD() to GD_STORE.clear()? It seems redundant to have two GDs in the name, especially when you still have to pass a third "GD" as an argument.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -49,35 +49,41 @@ from tensorflow.keras.models import *
 from tensorflow.keras.optimizers import *
 from tensorflow.keras.regularizers import *
 
-class SD_STORE:
+class GD_STORE:
     SESS = 'sess'
     SEGMENT_MODEL = 'segment_model'
 
     @staticmethod
-    def init_SD(SD, sess, segment_model):
-        SD[SD_STORE.SESS] = sess
-        SD[SD_STORE.SEGMENT_MODEL] = segment_model
+    def init_GD(GD, sess, segment_model):
+        GD[GD_STORE.SESS] = sess
+        GD[GD_STORE.SEGMENT_MODEL] = segment_model
 
     @staticmethod
-    def clear_SD(SD):
-        del SD[SD_STORE.SEGMENT_MODEL]
-        del SD[SD_STORE.SESS]
+    def clear_GD(GD):
+        del GD[GD_STORE.SEGMENT_MODEL]
+        del GD[GD_STORE.SESS]
 
-def get_init_model_and_sess(SD, device_name, gpu_count, segments_per_host,
+def get_init_model_and_sess(GD, device_name, gpu_count, segments_per_host,
                                model_architecture, compile_params, custom_function_map):
+    # plpy.info("GD keys {}".format(GD.keys()))
     # If a live session is present, re-use it. Otherwise, recreate it.
-    if SD_STORE.SESS in SD :
-        if SD_STORE.SEGMENT_MODEL not in SD:
-            plpy.error("Session and model should exist in SD after the first row"
+    if GD_STORE.SESS in GD:
+        # plpy.info("found sess in GD")
+        if GD_STORE.SEGMENT_MODEL not in GD:
+            plpy.error("Session and model should exist in GD after the first row"
                        "of the first iteration")
-        sess = SD[SD_STORE.SESS]
-        segment_model = SD[SD_STORE.SEGMENT_MODEL]
+        sess = GD[GD_STORE.SESS]
+        segment_model = GD[GD_STORE.SEGMENT_MODEL]
         K.set_session(sess)
     else:
+        # plpy.info("creating sess")
         sess = get_keras_session(device_name, gpu_count, segments_per_host)
         K.set_session(sess)
         segment_model = init_model(model_architecture, compile_params, custom_function_map)
-        SD_STORE.init_SD(SD, sess, segment_model)
+        GD_STORE.init_GD(GD, sess, segment_model)
+
+    plpy.info('sess object {}'.format(sess))
+    plpy.info('sess.graph object {}'.format(sess.graph))

Review comment:
       Remove the debugging plpy.info calls.

##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################
+    def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit for 2 iterations ##########
+        # iteration 1
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 2 (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        ###################### eval transition for last iteration ###########
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_eval_2_iterations_mcf_1_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit + eval for 2 iterations ##########
+        # iteration 1 fit
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 1 eval
+        self._run_eval_iteration(False, first_iter_keras_sess, first_iter_tf_graph, **kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        eval_first_iter_keras_sess = self.subject.K.get_session()
+        eval_first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess)
+        self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph)
+
+        # iteration 2 fit (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        # iteration 2 eval (last iteration)
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_multiple_2_iterations(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ############ fit multiple for 2 iterations ##########
+        # iteration 1
+        # first_iter_tf_graph is used to assert that calling fit_multiple clears the tf session
+        # and graph at the last buffer.
+        # It is fetched prior to calling the fit_transition(from fit_multiple) as when we create
+        # a session inside fit_transition, instead of creating a new graph it will use first_iter_tf_graph.
+        # This enables us to do the not equals assert.
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+        first_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        # iteration 2 (last iteration)
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+        last_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        self.assertNotEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+    def test_eval_multiple_any_iteration(self):
+        # This test tests 2 things:
+        # 1. Calling eval_transition from fit_multiple
+        # 2. Calling eval_transition from evaluate directly
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        # eval_iter_tf_graph1 is used to assert that calling eval clears the tf session and graph
+        # It is fetched prior to calling the eval_transition as when we create a session inside
+        # eval_transition, instead of creating a new graph it will use eval_iter_tf_graph1.
+        # This enables us to do the not equals assert.
+        eval_iter_tf_graph1 = self.subject.tf.get_default_graph()
+        eval_iter_keras_sess1 = self._run_eval_iteration(True, None, None, True, **kwargs)
+        eval_iter_keras_sess2 = self.subject.K.get_session()
+        eval_iter_tf_graph2 = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_iter_keras_sess1, eval_iter_keras_sess2)
+        self.assertNotEquals(eval_iter_tf_graph1, eval_iter_tf_graph2)
+        self._assert_gd_cleared(GD)
+
+    def _run_eval_iteration(self, final_iteration, prev_keras_sess, prev_tf_graph, called_from_fit_multiple=False, **kwargs):
+        self._test_internal_keras_eval_transition_first_buffer(final_iteration,
+                                                               **kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+
+        eval_first_buffer_keras_sess = kwargs['GD']['sess']
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+        eval_first_buffer_tf_graph = self.subject.tf.get_default_graph()
+
+        if not called_from_fit_multiple:
+            self.assertEquals(eval_first_buffer_keras_sess, prev_keras_sess)
+            self.assertEquals(eval_first_buffer_tf_graph, prev_tf_graph)
+
+        self._test_internal_keras_eval_transition_middle_buffer(final_iteration,
+                                                                **kwargs )
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+
+        self._test_internal_keras_eval_transition_last_buffer(final_iteration,
+                                                              **kwargs)
+        if final_iteration:
+            self._assert_gd_cleared(kwargs['GD'])
+            self.assertTrue(eval_first_buffer_keras_sess._closed)
+        else:
+            self._assert_gd_is_valid(kwargs['GD'])
+            self.assertFalse(eval_first_buffer_keras_sess._closed)
+        return eval_first_buffer_keras_sess
+
+    def _run_fit_iteration(self, **kwargs):
+        self._test_fit_transition_first_buffer_pass(**kwargs)
+        gd_first_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_first_buffer)
+        iter_sess = gd_first_buffer['sess']
+        self.assertFalse(iter_sess._closed)
+        self._assert_keras_session_same_as_gd_session(gd_first_buffer)
+
+        self._test_fit_transition_middle_buffer_pass(**kwargs)
+        gd_middle_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_middle_buffer)
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_last_buffer_pass(**kwargs)
+        gd_last_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_last_buffer)
+        self.assertFalse(iter_sess._closed)
+        return iter_sess
+
+    def _run_fit_multiple_iteration(self, **kwargs):
+        self._test_fit_transition_multiple_model_no_cache_first_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        iter_sess = kwargs['GD']['sess']
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_middle_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_last_buffer_pass(**kwargs)
+        self._assert_gd_cleared(kwargs['GD'])
+        self.assertTrue(iter_sess._closed)
+        return iter_sess
+
+    def _assert_keras_session_same_as_gd_session(self, gd):
+        sess = self.subject.K.get_session()
+        self.assertEquals(sess, gd['sess'])

Review comment:
       Instead of assuming we have the right session in the code, and verifying it in a unit test, I wonder if we could do something like this:
   
   - Each time we retrieve the session from GD, check whether it matches K.get_session()
        - If they match, just continue
        - If they don't match, check some kind of flag that indicates CMAKE_BUILD_TYPE=debug (?)
                 - if CMAKE_BUILD_TYPE=debug, raise Exception so that it fails in the pipeline
                 - If CMAKE_BUILD_TYPE=release, just call K.set_session(sess) and continue
   
   This seems like more ideal behavior than crashing if they don't match.  I imagine if this does fail, it would probably not be due to a code change but due to some kind of weird race condition or GUC setting a user has in their environment (like backend session timeout, etc.?)... which is not something a unit test is likely to catch anyway, since it's always run under the same isolated conditions that we've already tested.
   
   The only unknown would be how exactly to check for CMAKE_BUILD_TYPE in Python, maybe with m4?  Or by writing a small `is_madlib_debug_build()` helper function in LANGUAGE C that returns True or False?

##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################
+    def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit for 2 iterations ##########
+        # iteration 1
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 2 (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        ###################### eval transition for last iteration ###########
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_eval_2_iterations_mcf_1_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit + eval for 2 iterations ##########
+        # iteration 1 fit
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 1 eval
+        self._run_eval_iteration(False, first_iter_keras_sess, first_iter_tf_graph, **kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        eval_first_iter_keras_sess = self.subject.K.get_session()
+        eval_first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess)
+        self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph)
+
+        # iteration 2 fit (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        # iteration 2 eval (last iteration)
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_multiple_2_iterations(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ############ fit multiple for 2 iterations ##########
+        # iteration 1
+        # first_iter_tf_graph is used to assert that calling fit_multiple clears the tf session
+        # and graph at the last buffer.
+        # It is fetched prior to calling the fit_transition(from fit_multiple) as when we create
+        # a session inside fit_transition, instead of creating a new graph it will use first_iter_tf_graph.
+        # This enables us to do the not equals assert.
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+        first_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        # iteration 2 (last iteration)
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+        last_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        self.assertNotEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+    def test_eval_multiple_any_iteration(self):
+        # This test tests 2 things:
+        # 1. Calling eval_transition from fit_multiple
+        # 2. Calling eval_transition from evaluate directly
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        # eval_iter_tf_graph1 is used to assert that calling eval clears the tf session and graph
+        # It is fetched prior to calling the eval_transition as when we create a session inside
+        # eval_transition, instead of creating a new graph it will use eval_iter_tf_graph1.
+        # This enables us to do the not equals assert.
+        eval_iter_tf_graph1 = self.subject.tf.get_default_graph()
+        eval_iter_keras_sess1 = self._run_eval_iteration(True, None, None, True, **kwargs)
+        eval_iter_keras_sess2 = self.subject.K.get_session()
+        eval_iter_tf_graph2 = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_iter_keras_sess1, eval_iter_keras_sess2)
+        self.assertNotEquals(eval_iter_tf_graph1, eval_iter_tf_graph2)
+        self._assert_gd_cleared(GD)
+
+    def _run_eval_iteration(self, final_iteration, prev_keras_sess, prev_tf_graph, called_from_fit_multiple=False, **kwargs):
+        self._test_internal_keras_eval_transition_first_buffer(final_iteration,
+                                                               **kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+
+        eval_first_buffer_keras_sess = kwargs['GD']['sess']
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+        eval_first_buffer_tf_graph = self.subject.tf.get_default_graph()
+
+        if not called_from_fit_multiple:
+            self.assertEquals(eval_first_buffer_keras_sess, prev_keras_sess)
+            self.assertEquals(eval_first_buffer_tf_graph, prev_tf_graph)
+
+        self._test_internal_keras_eval_transition_middle_buffer(final_iteration,
+                                                                **kwargs )
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+
+        self._test_internal_keras_eval_transition_last_buffer(final_iteration,
+                                                              **kwargs)
+        if final_iteration:
+            self._assert_gd_cleared(kwargs['GD'])
+            self.assertTrue(eval_first_buffer_keras_sess._closed)
+        else:
+            self._assert_gd_is_valid(kwargs['GD'])
+            self.assertFalse(eval_first_buffer_keras_sess._closed)
+        return eval_first_buffer_keras_sess
+
+    def _run_fit_iteration(self, **kwargs):
+        self._test_fit_transition_first_buffer_pass(**kwargs)
+        gd_first_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_first_buffer)
+        iter_sess = gd_first_buffer['sess']
+        self.assertFalse(iter_sess._closed)
+        self._assert_keras_session_same_as_gd_session(gd_first_buffer)
+
+        self._test_fit_transition_middle_buffer_pass(**kwargs)
+        gd_middle_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_middle_buffer)
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_last_buffer_pass(**kwargs)
+        gd_last_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_last_buffer)
+        self.assertFalse(iter_sess._closed)
+        return iter_sess
+
+    def _run_fit_multiple_iteration(self, **kwargs):
+        self._test_fit_transition_multiple_model_no_cache_first_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        iter_sess = kwargs['GD']['sess']
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_middle_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_last_buffer_pass(**kwargs)
+        self._assert_gd_cleared(kwargs['GD'])
+        self.assertTrue(iter_sess._closed)
+        return iter_sess

Review comment:
       I notice we still close the tensorflow session and re-open a new one after each hop.  Did we try keeping it open and just calling clear_session to start a new graph instead of a whole new session?  I'm guessing the problem with that is that tensorflow doesn't fully free the memory unless you close it?  If so, we should re-evaluate this when we move to 2.x, maybe they have fixed some things like that.

##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################
+    def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit for 2 iterations ##########
+        # iteration 1
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 2 (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        ###################### eval transition for last iteration ###########
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_eval_2_iterations_mcf_1_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit + eval for 2 iterations ##########
+        # iteration 1 fit
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 1 eval
+        self._run_eval_iteration(False, first_iter_keras_sess, first_iter_tf_graph, **kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        eval_first_iter_keras_sess = self.subject.K.get_session()
+        eval_first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess)
+        self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph)
+
+        # iteration 2 fit (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        # iteration 2 eval (last iteration)
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_multiple_2_iterations(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ############ fit multiple for 2 iterations ##########
+        # iteration 1
+        # first_iter_tf_graph is used to assert that calling fit_multiple clears the tf session
+        # and graph at the last buffer.

Review comment:
       Should this say "closes the tf session" rather than clears?

##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -312,103 +201,129 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is only called for the last buffer
-        self.assertEqual(0, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue('cache_set' not in k['SD'])
-        self.assertTrue(k['SD']['x_train'])
-        self.assertTrue(k['SD']['y_train'])
-
-    def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue('cache_set' not in k['GD'])
+        self.assertTrue(k['GD']['x_train'])
+        self.assertTrue(k['GD']['y_train'])
+
+    def _test_fit_transition_last_buffer_pass(self, is_platform_pg, **kwargs):
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        # last iteration Call
-
-        multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
-
         state = starting_image_count
         new_state = self.subject.fit_transition(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **k)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, "todo-remove",
+            **kwargs)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
-        weights = np.rint(state[1:]).astype(np.int)
+        # We need to assert that the weights should be multiplied by final image count.
+        # So we divide and round the weights so that we can compare them with
+        # the original weights. They will only be equal if the learnt weights
+        # only go up or down by 1 which seems to be the case for our unit test.
+        # There might be a better way to assert this but this is good enough for now.
+        weights = np.rint(state[1:]/self.total_images_per_seg[0]).astype(np.int)
+        self.assertTrue(((weights == self.model_weights).all()))
+
+        self.assertEqual(ending_image_count, image_count)
+
+    def _test_internal_keras_eval_transition_first_buffer(self, is_platform_pg,
+                                                          last_iteration = False,
+                                                          **kwargs):
+        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
+        ending_image_count = len(self.dependent_var_int)
+
+        state = [0,0,0]
+        new_state = self.subject.internal_keras_eval_transition(
+            state, self.dependent_var , self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(),
+            self.serialized_weights, self.compile_params, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+            last_iteration, None, **kwargs)
+
+        agg_loss, agg_accuracy, image_count = new_state
+
         self.assertEqual(ending_image_count, image_count)
-        # weights should be multiplied by final image count
-        self.assertTrue((multiplied_weights == weights).all())
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # loss and accuracy should be unchanged
+        self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
+        self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
 
-        # Non-last iteration Call
-        self.subject.K.set_session.reset_mock()
-        self.subject.K.clear_session.reset_mock()
+    def _test_internal_keras_eval_transition_last_buffer(self, is_platform_pg,
+                                                     last_iteration = False,
+                                                     **kwargs):
+        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
-        multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+        state = [self.loss * starting_image_count,
+                 self.accuracy * starting_image_count,
+                 starting_image_count]
 
-        state = starting_image_count
-        new_state = self.subject.fit_transition(
-            state, self.dependent_var, self.independent_var,
+        new_state = self.subject.internal_keras_eval_transition(
+            state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, False, **k)
+            self.model.to_json(),
+            'dummy_model_weights', None, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+            last_iteration, **kwargs)
+        agg_loss, agg_accuracy, image_count = new_state
 
-        state = np.fromstring(new_state, dtype=np.float32)
-        image_count = state[0]
-        weights = np.rint(state[1:]).astype(np.int)
         self.assertEqual(ending_image_count, image_count)
-        # weights should be multiplied by final image count
-        self.assertTrue((multiplied_weights == weights).all())
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must get called for the last buffer in gpdb,
-        #  but not in postgres
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-
-    def test_fit_transition_multiple_model_no_cache_last_buffer_pass(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+        # loss and accuracy should be unchanged
+        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
+        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
 
-        starting_image_count = 2*len(self.dependent_var_int)
+    def _test_internal_keras_eval_transition_middle_buffer(self,
+                                                           last_iteration = False,
+                                                           **kwargs):
+        starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        # last iteration Call
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+
+        state = [self.loss * starting_image_count,
+                 self.accuracy * starting_image_count, starting_image_count]
+
+        new_state = self.subject.internal_keras_eval_transition(
+            state, self.dependent_var , self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(),
+            'dummy_model_weights', None, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+            last_iteration, **kwargs)
+
+        agg_loss, agg_accuracy, image_count = new_state
+
+        self.assertEqual(ending_image_count, image_count)
+        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
+        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
+
+    def _test_fit_transition_multiple_model_no_cache_last_buffer_pass(self,
+                                                                     **kwargs):
+        starting_image_count = 2*len(self.dependent_var_int)
 
         state = starting_image_count
         new_state = self.subject.fit_transition(
-            state   , self.dependent_var, self.independent_var,
+            state , self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True, **k)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, "todo-remove",
+            True, **kwargs)

Review comment:
       "todo-remove" : is this a note?

##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################
+    def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit for 2 iterations ##########
+        # iteration 1
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 2 (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        ###################### eval transition for last iteration ###########
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)

Review comment:
       I want to make sure I understand this.  We want eval and fit to use the same session & graph.  But after the last call to eval, it should have closed the session and destroyed the graph.  So I guess we are checking that they actually got destroyed and a brand new session and graph should be created when we call `get_session()` and `get_default_graph()` now?
   
   If that's right, let's add a comment explaining it.  And if it's not right, definitely add one ;-) 

##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################

Review comment:
       These tests are great, I love how this is organized!  Much cleaner to read now.

##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################
+    def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit for 2 iterations ##########
+        # iteration 1
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 2 (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        ###################### eval transition for last iteration ###########
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_eval_2_iterations_mcf_1_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit + eval for 2 iterations ##########
+        # iteration 1 fit
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 1 eval
+        self._run_eval_iteration(False, first_iter_keras_sess, first_iter_tf_graph, **kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        eval_first_iter_keras_sess = self.subject.K.get_session()
+        eval_first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess)
+        self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph)
+
+        # iteration 2 fit (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        # iteration 2 eval (last iteration)
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_multiple_2_iterations(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ############ fit multiple for 2 iterations ##########
+        # iteration 1
+        # first_iter_tf_graph is used to assert that calling fit_multiple clears the tf session
+        # and graph at the last buffer.
+        # It is fetched prior to calling the fit_transition(from fit_multiple) as when we create
+        # a session inside fit_transition, instead of creating a new graph it will use first_iter_tf_graph.
+        # This enables us to do the not equals assert.
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+        first_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        # iteration 2 (last iteration)
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+        last_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        self.assertNotEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+    def test_eval_multiple_any_iteration(self):
+        # This test tests 2 things:
+        # 1. Calling eval_transition from fit_multiple
+        # 2. Calling eval_transition from evaluate directly
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        # eval_iter_tf_graph1 is used to assert that calling eval clears the tf session and graph
+        # It is fetched prior to calling the eval_transition as when we create a session inside
+        # eval_transition, instead of creating a new graph it will use eval_iter_tf_graph1.
+        # This enables us to do the not equals assert.
+        eval_iter_tf_graph1 = self.subject.tf.get_default_graph()
+        eval_iter_keras_sess1 = self._run_eval_iteration(True, None, None, True, **kwargs)
+        eval_iter_keras_sess2 = self.subject.K.get_session()
+        eval_iter_tf_graph2 = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_iter_keras_sess1, eval_iter_keras_sess2)
+        self.assertNotEquals(eval_iter_tf_graph1, eval_iter_tf_graph2)
+        self._assert_gd_cleared(GD)
+
+    def _run_eval_iteration(self, final_iteration, prev_keras_sess, prev_tf_graph, called_from_fit_multiple=False, **kwargs):
+        self._test_internal_keras_eval_transition_first_buffer(final_iteration,
+                                                               **kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+
+        eval_first_buffer_keras_sess = kwargs['GD']['sess']
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+        eval_first_buffer_tf_graph = self.subject.tf.get_default_graph()
+
+        if not called_from_fit_multiple:
+            self.assertEquals(eval_first_buffer_keras_sess, prev_keras_sess)
+            self.assertEquals(eval_first_buffer_tf_graph, prev_tf_graph)
+
+        self._test_internal_keras_eval_transition_middle_buffer(final_iteration,
+                                                                **kwargs )
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+
+        self._test_internal_keras_eval_transition_last_buffer(final_iteration,
+                                                              **kwargs)
+        if final_iteration:
+            self._assert_gd_cleared(kwargs['GD'])
+            self.assertTrue(eval_first_buffer_keras_sess._closed)
+        else:
+            self._assert_gd_is_valid(kwargs['GD'])
+            self.assertFalse(eval_first_buffer_keras_sess._closed)
+        return eval_first_buffer_keras_sess
+
+    def _run_fit_iteration(self, **kwargs):
+        self._test_fit_transition_first_buffer_pass(**kwargs)
+        gd_first_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_first_buffer)
+        iter_sess = gd_first_buffer['sess']
+        self.assertFalse(iter_sess._closed)
+        self._assert_keras_session_same_as_gd_session(gd_first_buffer)
+
+        self._test_fit_transition_middle_buffer_pass(**kwargs)
+        gd_middle_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_middle_buffer)
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_last_buffer_pass(**kwargs)
+        gd_last_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_last_buffer)
+        self.assertFalse(iter_sess._closed)
+        return iter_sess
+
+    def _run_fit_multiple_iteration(self, **kwargs):
+        self._test_fit_transition_multiple_model_no_cache_first_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        iter_sess = kwargs['GD']['sess']
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_middle_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_last_buffer_pass(**kwargs)
+        self._assert_gd_cleared(kwargs['GD'])
+        self.assertTrue(iter_sess._closed)
+        return iter_sess

Review comment:
       Also:  can't we at least keep the session open for eval to use?

##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -312,103 +201,129 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is only called for the last buffer
-        self.assertEqual(0, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue('cache_set' not in k['SD'])
-        self.assertTrue(k['SD']['x_train'])
-        self.assertTrue(k['SD']['y_train'])
-
-    def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue('cache_set' not in k['GD'])
+        self.assertTrue(k['GD']['x_train'])
+        self.assertTrue(k['GD']['y_train'])
+
+    def _test_fit_transition_last_buffer_pass(self, is_platform_pg, **kwargs):
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        # last iteration Call
-
-        multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
-
         state = starting_image_count
         new_state = self.subject.fit_transition(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **k)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, "todo-remove",
+            **kwargs)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
-        weights = np.rint(state[1:]).astype(np.int)
+        # We need to assert that the weights should be multiplied by final image count.
+        # So we divide and round the weights so that we can compare them with
+        # the original weights. They will only be equal if the learnt weights
+        # only go up or down by 1 which seems to be the case for our unit test.
+        # There might be a better way to assert this but this is good enough for now.
+        weights = np.rint(state[1:]/self.total_images_per_seg[0]).astype(np.int)
+        self.assertTrue(((weights == self.model_weights).all()))

Review comment:
       I think the weights are not supposed to change at all, since we're mocking calls to keras fit--it should just return the same as what was passed in.
   
   Previously this was done with a helper function at the top of this file called `mult()`.  Used in a few places to do a similar comparison like this:
   ```
   multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
   weights = np.rint(state[1:]).astype(np.int)
   self.assertTrue((weights == multiplied_weights).all())
   ```
   I guess we probably don't need the helper function, so we should remove it.  But I think we should stick to multiplying rather than dividing.  If the result returned is off by less than the number of images per segment, then dividing and rounding will end up giving us a false PASS when it should be a FAIL.   Multiplying and comparing will detect a FAIL even if it's only off by 1.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org