Posted to commits@madlib.apache.org by nk...@apache.org on 2020/12/09 00:39:12 UTC

[madlib] 03/07: DL: Add unit test to assert for session and graph objects

This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 963568aae5b80803dab81dc10f6ffdc7ece30e89
Author: Ekta Khanna <ek...@vmware.com>
AuthorDate: Wed Dec 2 15:10:56 2020 -0800

    DL: Add unit test to assert for session and graph objects
    
    JIRA: MADLIB-1438
    
    Prior to this commit, we mocked the calls to K.set_session and
    K.clear_session. As part of this commit, we let Keras/TF create and set
    the sessions to validate the session/graph behavior.
    
    Co-authored-by: Ekta Khanna <ek...@vmware.com>
---
 .../test/unit_tests/test_madlib_keras.py_in        | 828 +++++++++------------
 1 file changed, 355 insertions(+), 473 deletions(-)
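
The new tests exercise real Keras/TF session and graph objects instead of mocks. A minimal sketch of the behavior they rely on (illustrative only, not part of the commit; assumes TF 1.x, where tf.keras exposes the graph/session backend API, with a plain GD dict standing in for the per-segment state used in the test file):

    import tensorflow as tf
    from tensorflow.keras import backend as K

    GD = {}                               # stands in for the per-segment GD dict
    GD['sess'] = K.get_session()          # let Keras/TF create a real session
    graph_before = tf.get_default_graph()

    # Every buffer within an iteration reuses the cached session.
    assert K.get_session() is GD['sess']

    # On the final iteration's last buffer, the Keras session/graph state is
    # reset and GD is cleared.
    K.clear_session()
    GD.clear()
    assert tf.get_default_graph() is not graph_before  # a fresh graph replaces the old one
    assert not GD                                       # GD has been emptied
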

diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index a97e07e..1b0ee8d 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -45,7 +45,7 @@ except:
 def mult(k,arr):
     return [ k*a for a in arr ]
 
-class MadlibKerasFitTestCase(unittest.TestCase):
+class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
         patches = {
@@ -71,6 +71,8 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.model_weights = [3,4,5,6]
         self.serialized_weights = np.array(self.model_weights, dtype=np.float32
                                            ).tostring()
+        self.loss = 0.5947071313858032
+        self.accuracy = 1.0
         self.dist_key_mapping = [0,1,2]
         self.accessible_gpus_for_seg = [0]
 
@@ -89,100 +91,52 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         self.dummy_prev_weights = 'dummy weights'
 
+        # Mock calls to tf.keras fit
+        model_class = self.subject.tf.keras.Model
+        model_class.fit = MagicMock()
+
     def tearDown(self):
         self.module_patcher.stop()
+        self.subject.K.clear_session()
 
-    def _test_fit_transition_first_buffer_pass(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+    def _test_fit_transition_first_buffer_pass(self, is_platform_pg, **kwargs):
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
-        starting_image_count = 0
         ending_image_count = len(self.dependent_var_int)
 
-        # last iteration Call
         previous_state = np.array(self.model_weights, dtype=np.float32)
 
-        k = {'SD': {}}
-
         new_state = self.subject.fit_transition(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, previous_state.tostring(), True, **k)
-
+            self.accessible_gpus_for_seg, previous_state.tostring(), "todo-remove", **kwargs)
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the first buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue(k['SD']['segment_model'])
-
-        # Non-last iteration Call
-        self.subject.K.set_session.reset_mock()
-        self.subject.K.clear_session.reset_mock()
-        previous_state = np.array(self.model_weights, dtype=np.float32)
-
-        k = {'SD' : {}}
-
-        new_state = self.subject.fit_transition(
-            None, self.dependent_var, self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, previous_state.tostring(), False, **k)
-
-        image_count = new_state
 
-        self.assertEqual(ending_image_count, image_count)
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the first buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue(k['SD']['segment_model'])
-
-    def test_fit_transition_multiple_model_no_cache_first_buffer_pass(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
-        starting_image_count = 0
+    def _test_fit_transition_multiple_model_no_cache_first_buffer_pass(self,
+                                                                      **kwargs):
         ending_image_count = len(self.dependent_var_int)
 
         previous_weights = np.array(self.model_weights, dtype=np.float32)
 
-        k = {'SD': {}}
-
         new_state = self.subject.fit_transition(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), True, True, **k)
+            self.accessible_gpus_for_seg, previous_weights.tostring(),
+            "todo-remove", True, **kwargs)
 
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session must not be called for the first buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue(k['SD']['segment_model'])
 
     def test_fit_transition_multiple_model_cache_first_buffer_pass(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
-        starting_image_count = 0
         ending_image_count = len(self.dependent_var_int)
 
         previous_weights = np.array(self.model_weights, dtype=np.float32)
 
-        k = {'SD': {}}
-
+        k = {'GD': {}}
         new_state = self.subject.fit_multiple_transition_caching(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
@@ -192,108 +146,47 @@ class MadlibKerasFitTestCase(unittest.TestCase):
 
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session should only be called for the last row
-        self.assertEqual(0, self.subject.K.set_session.call_count)
-        # Clear session must not be called for the first buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue('cache_set' not in k['SD'])
-        self.assertTrue(k['SD']['x_train'])
-        self.assertTrue(k['SD']['y_train'])
-
-    def _test_fit_transition_middle_buffer_pass(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue('cache_set' not in k['GD'])
+        self.assertTrue(k['GD']['x_train'])
+        self.assertTrue(k['GD']['y_train'])
+
+    def _test_fit_transition_middle_buffer_pass(self, is_platform_pg, **kwargs):
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        # last iteration Call
-
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model': self.model, 'sess': Mock()}}
-
-        state = starting_image_count
-        new_state = self.subject.fit_transition(
-            state, self.dependent_var, self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **k)
-
-        image_count = new_state
-        self.assertEqual(ending_image_count, image_count)
-
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-
-        # Non-last iteration Call
-        self.subject.K.set_session.reset_mock()
-        self.subject.K.clear_session.reset_mock()
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model': self.model, 'sess': Mock()}}
-
         state = starting_image_count
         new_state = self.subject.fit_transition(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, False, **k)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **kwargs)
 
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-
-    def test_fit_transition_multiple_model_no_cache_middle_buffer_pass(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
 
+    def _test_fit_transition_multiple_model_no_cache_middle_buffer_pass(self,
+                                                                        **kwargs):
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        # last iteration Call
-
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model': self.model, 'sess': Mock()}}
-
         state = starting_image_count
         new_state = self.subject.fit_transition(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True, **k)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True,
+            **kwargs)
 
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
 
     def test_fit_transition_multiple_model_cache_middle_buffer_pass(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
-
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
@@ -303,7 +196,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
         y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
 
-        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+        k = {'GD': {'x_train': x_train, 'y_train': y_train}}
 
         state = starting_image_count
         new_state = self.subject.fit_multiple_transition_caching(
@@ -312,103 +205,125 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is only called for the last buffer
-        self.assertEqual(0, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue('cache_set' not in k['SD'])
-        self.assertTrue(k['SD']['x_train'])
-        self.assertTrue(k['SD']['y_train'])
-
-    def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue('cache_set' not in k['GD'])
+        self.assertTrue(k['GD']['x_train'])
+        self.assertTrue(k['GD']['y_train'])
+
+    def _test_fit_transition_last_buffer_pass(self, is_platform_pg, **kwargs):
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        # last iteration Call
-
-        multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
-
         state = starting_image_count
+        previous_state = np.array(self.model_weights, dtype=np.float32)
         new_state = self.subject.fit_transition(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **k)
-
+            self.accessible_gpus_for_seg, previous_state.tostring(), "todo-remove",
+            **kwargs)
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
-        weights = np.rint(state[1:]).astype(np.int)
+        # Assert that the weights get multiplied by the final image count.
+        weights = state[1:]
+        multiplied_weights = mult(self.total_images_per_seg[0], self.model_weights)
+        self.assertTrue((weights == multiplied_weights).all())
+        self.assertEqual(ending_image_count, image_count)
+
+    def _test_internal_keras_eval_transition_first_buffer(self, is_platform_pg,
+                                                          last_iteration = False,
+                                                          **kwargs):
+        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
+        ending_image_count = len(self.dependent_var_int)
+
+        state = [0,0,0]
+        new_state = self.subject.internal_keras_eval_transition(
+            state, self.dependent_var , self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(),
+            self.serialized_weights, self.compile_params, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+            last_iteration, None, **kwargs)
+
+        agg_loss, agg_accuracy, image_count = new_state
+
         self.assertEqual(ending_image_count, image_count)
-        # weights should be multiplied by final image count
-        self.assertTrue((multiplied_weights == weights).all())
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # loss and accuracy should be unchanged
+        self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
+        self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
 
-        # Non-last iteration Call
-        self.subject.K.set_session.reset_mock()
-        self.subject.K.clear_session.reset_mock()
+    def _test_internal_keras_eval_transition_last_buffer(self, is_platform_pg,
+                                                     last_iteration = False,
+                                                     **kwargs):
+        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
-        multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+        state = [self.loss * starting_image_count,
+                 self.accuracy * starting_image_count,
+                 starting_image_count]
 
-        state = starting_image_count
-        new_state = self.subject.fit_transition(
-            state, self.dependent_var, self.independent_var,
+        new_state = self.subject.internal_keras_eval_transition(
+            state, self.dependent_var , self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, False, **k)
+            self.model.to_json(),
+            'dummy_model_weights', None, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+            last_iteration, **kwargs)
+        agg_loss, agg_accuracy, image_count = new_state
 
-        state = np.fromstring(new_state, dtype=np.float32)
-        image_count = state[0]
-        weights = np.rint(state[1:]).astype(np.int)
         self.assertEqual(ending_image_count, image_count)
-        # weights should be multiplied by final image count
-        self.assertTrue((multiplied_weights == weights).all())
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must get called for the last buffer in gpdb,
-        #  but not in postgres
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-
-    def test_fit_transition_multiple_model_no_cache_last_buffer_pass(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+        # loss and accuracy should be unchanged
+        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
+        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
 
-        starting_image_count = 2*len(self.dependent_var_int)
+    def _test_internal_keras_eval_transition_middle_buffer(self,
+                                                           last_iteration = False,
+                                                           **kwargs):
+        starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        # last iteration Call
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+
+        state = [self.loss * starting_image_count,
+                 self.accuracy * starting_image_count, starting_image_count]
+
+        new_state = self.subject.internal_keras_eval_transition(
+            state, self.dependent_var , self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(),
+            'dummy_model_weights', None, 0,
+            self.dist_key_mapping, 0, 4,
+            self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+            last_iteration, **kwargs)
+
+        agg_loss, agg_accuracy, image_count = new_state
+
+        self.assertEqual(ending_image_count, image_count)
+        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
+        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
+
+    def _test_fit_transition_multiple_model_no_cache_last_buffer_pass(self,
+                                                                     **kwargs):
+        starting_image_count = 2*len(self.dependent_var_int)
 
         state = starting_image_count
         new_state = self.subject.fit_transition(
-            state   , self.dependent_var, self.independent_var,
+            state , self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True, **k)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, "todo-remove",
+            True, **kwargs)
 
         state = np.fromstring(new_state, dtype=np.float32)
         weights = np.rint(state[0:]).astype(np.int)
@@ -417,15 +332,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # fit multiple
         self.assertEqual(len(self.model_weights), len(weights))
 
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-
     def test_fit_transition_multiple_model_cache_last_buffer_pass(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
-
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
@@ -437,16 +344,18 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
         y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
 
-        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+        k = {'GD': {'x_train': x_train, 'y_train': y_train}}
 
         state = starting_image_count
+        graph1 = self.subject.tf.get_default_graph()
         new_state = self.subject.fit_multiple_transition_caching(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
             self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
-
+        graph2 = self.subject.tf.get_default_graph()
+        self.assertNotEquals(graph1, graph2)
         state = np.fromstring(new_state, dtype=np.float32)
         weights = np.rint(state[0:]).astype(np.int)
 
@@ -454,21 +363,13 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # fit multiple
         self.assertEqual(len(self.model_weights), len(weights))
 
-        # set_session is only called for the last buffer
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must get called for the last buffer
-        self.assertEqual(1, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue(k['SD']['cache_set'])
-        self.assertTrue(k['SD']['x_train'])
-        self.assertTrue(k['SD']['y_train'])
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue(k['GD']['cache_set'])
+        self.assertTrue(k['GD']['x_train'])
+        self.assertTrue(k['GD']['y_train'])
 
     def test_fit_transition_multiple_model_cache_filled_pass(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
-
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
@@ -482,15 +383,20 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
         y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
 
-        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
-
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                                     '/cpu:0', self.serialized_weights)
+        s1 = self.subject.K.get_session()
+        k = {'GD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True,
+                    'sess': s1, 'segment_model': self.model}}
+        graph1 = self.subject.tf.get_default_graph()
         new_state = self.subject.fit_multiple_transition_caching(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
             self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
-
+        graph2 = self.subject.tf.get_default_graph()
+        self.assertNotEquals(graph1, graph2)
         state = np.fromstring(new_state, dtype=np.float32)
         weights = np.rint(state[0:]).astype(np.int)
 
@@ -498,21 +404,13 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # fit multiple
         self.assertEqual(len(self.model_weights), len(weights))
 
-        # set_session is only called for the last buffer
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must get called for the last buffer
-        self.assertEqual(1, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue(k['SD']['cache_set'])
-        self.assertTrue(k['SD']['x_train'])
-        self.assertTrue(k['SD']['y_train'])
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue(k['GD']['cache_set'])
+        self.assertTrue(k['GD']['x_train'])
+        self.assertTrue(k['GD']['y_train'])
 
     def test_fit_transition_multiple_model_cache_filled_final_training_pass(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
-
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
@@ -526,14 +424,16 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
         y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
 
-        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
-
+        k = {'GD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+        graph1 = self.subject.tf.get_default_graph()
         new_state = self.subject.fit_multiple_transition_caching(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+        graph2 = self.subject.tf.get_default_graph()
+        self.assertNotEquals(graph1, graph2)
 
         state = np.fromstring(new_state, dtype=np.float32)
         weights = np.rint(state[0:]).astype(np.int)
@@ -542,14 +442,11 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # fit multiple
         self.assertEqual(len(self.model_weights), len(weights))
 
-        # set_session is only called for the last buffer
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must get called for the last buffer
-        self.assertEqual(1, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue('cache_set' not in k['SD'])
-        self.assertTrue('x_train' not in k['SD'])
-        self.assertTrue('y_train' not in k['SD'])
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue('cache_set' not in k['GD'])
+        self.assertTrue('x_train' not in k['GD'])
+        self.assertTrue('y_train' not in k['GD'])
 
     def test_fit_transition_first_buffer_pass_pg(self):
         self._test_fit_transition_first_buffer_pass(True)
@@ -569,6 +466,204 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################
+    def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit for 2 iterations ##########
+        # iteration 1
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 2 (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        ###################### eval transition for last iteration ###########
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_eval_2_iterations_mcf_1_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit + eval for 2 iterations ##########
+        # iteration 1 fit
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 1 eval
+        self._run_eval_iteration(False, first_iter_keras_sess, first_iter_tf_graph, **kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        eval_first_iter_keras_sess = self.subject.K.get_session()
+        eval_first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess)
+        self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph)
+
+        # iteration 2 fit (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        # iteration 2 eval (last iteration)
+        # eval and fit use the same session & graph for all iterations.
+        # After the last eval call (last iteration), we want to assert that
+        # the session was closed and the graph cleared out.
+        # To assert this we call get_session() and get_default_graph(), which
+        # now return a new session and a new graph, and assert that they are
+        # not equal to the previous iteration's session and graph.
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_multiple_2_iterations(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ############ fit multiple for 2 iterations ##########
+        # iteration 1
+        # first_iter_tf_graph is used to assert that calling fit_multiple closes the tf session
+        # and graph at the last buffer.
+        # It is fetched prior to calling fit_transition (from fit_multiple) because when we
+        # create a session inside fit_transition, it reuses first_iter_tf_graph instead of
+        # creating a new graph. This enables us to do the not-equals assert.
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+        first_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        # iteration 2 (last iteration)
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+        last_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        self.assertNotEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+    def test_eval_multiple_any_iteration(self):
+        # This test covers 2 scenarios:
+        # 1. Calling eval_transition from fit_multiple
+        # 2. Calling eval_transition from evaluate directly
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        # eval_iter_tf_graph1 is used to assert that calling eval clears the tf session and graph.
+        # It is fetched prior to calling eval_transition because when we create a session inside
+        # eval_transition, it reuses eval_iter_tf_graph1 instead of creating a new graph.
+        # This enables us to do the not-equals assert.
+        eval_iter_tf_graph1 = self.subject.tf.get_default_graph()
+        eval_iter_keras_sess1 = self._run_eval_iteration(True, None, None, True, **kwargs)
+        eval_iter_keras_sess2 = self.subject.K.get_session()
+        eval_iter_tf_graph2 = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_iter_keras_sess1, eval_iter_keras_sess2)
+        self.assertNotEquals(eval_iter_tf_graph1, eval_iter_tf_graph2)
+        self._assert_gd_cleared(GD)
+
+    def _run_eval_iteration(self, final_iteration, prev_keras_sess, prev_tf_graph, called_from_fit_multiple=False, **kwargs):
+        self._test_internal_keras_eval_transition_first_buffer(final_iteration,
+                                                               **kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+
+        eval_first_buffer_keras_sess = kwargs['GD']['sess']
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+        eval_first_buffer_tf_graph = self.subject.tf.get_default_graph()
+
+        if not called_from_fit_multiple:
+            self.assertEquals(eval_first_buffer_keras_sess, prev_keras_sess)
+            self.assertEquals(eval_first_buffer_tf_graph, prev_tf_graph)
+
+        self._test_internal_keras_eval_transition_middle_buffer(final_iteration,
+                                                                **kwargs )
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+
+        self._test_internal_keras_eval_transition_last_buffer(final_iteration,
+                                                              **kwargs)
+        if final_iteration:
+            self._assert_gd_cleared(kwargs['GD'])
+            self.assertTrue(eval_first_buffer_keras_sess._closed)
+        else:
+            self._assert_gd_is_valid(kwargs['GD'])
+            self.assertFalse(eval_first_buffer_keras_sess._closed)
+        return eval_first_buffer_keras_sess
+
+    def _run_fit_iteration(self, **kwargs):
+        self._test_fit_transition_first_buffer_pass(**kwargs)
+        gd_first_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_first_buffer)
+        iter_sess = gd_first_buffer['sess']
+        self.assertFalse(iter_sess._closed)
+        self._assert_keras_session_same_as_gd_session(gd_first_buffer)
+
+        self._test_fit_transition_middle_buffer_pass(**kwargs)
+        gd_middle_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_middle_buffer)
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_last_buffer_pass(**kwargs)
+        gd_last_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_last_buffer)
+        self.assertFalse(iter_sess._closed)
+        return iter_sess
+
+    def _run_fit_multiple_iteration(self, **kwargs):
+        self._test_fit_transition_multiple_model_no_cache_first_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        iter_sess = kwargs['GD']['sess']
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_middle_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_last_buffer_pass(**kwargs)
+        self._assert_gd_cleared(kwargs['GD'])
+        self.assertTrue(iter_sess._closed)
+        return iter_sess
+
+    def _assert_keras_session_same_as_gd_session(self, gd):
+        sess = self.subject.K.get_session()
+        self.assertEquals(sess, gd['sess'])
+
+    def _assert_gd_cleared(self, gd):
+        self.assertEquals(0, len(gd.keys()))
+
+    def _assert_gd_is_valid(self, gd):
+        self.assertTrue(gd['sess'])
+        self.assertTrue(gd['segment_model'])
+
+    ################################################################
+
     def test_fit_transition_first_tuple_none_ind_var_dep_var(self):
         k = {}
         self.assertEqual('dummy_state',
@@ -1580,7 +1675,7 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
             "This is not valid PostgresSQL: SELECT {}[1]".format(metrics)
         )
 
-class MadlibKerasEvaluationTestCase(unittest.TestCase):
+class MadlibKerasEvaluationMergeFinalTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
         patches = {
@@ -1628,228 +1723,6 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
     def tearDown(self):
         self.module_patcher.stop()
 
-    def _test_internal_keras_eval_transition_first_buffer(self, is_platform_pg):
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
-        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
-        starting_image_count = 0
-        ending_image_count = len(self.dependent_var_int)
-
-        # last iteration call
-
-        k = {'SD' : {}}
-        state = [0,0,0]
-        new_state = self.subject.internal_keras_eval_transition(
-            state, self.dependent_var , self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(),
-            self.serialized_weights, self.compile_params, 0,
-            self.dist_key_mapping, 0, 4,
-            self.total_images_per_seg, False, self.accessible_gpus_for_seg, True, **k)
-
-        agg_loss, agg_accuracy, image_count = new_state
-
-        self.assertEqual(ending_image_count, image_count)
-        # Call set_session once for gpdb (but not for postgres)
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # loss and accuracy should be unchanged
-        self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
-        self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
-        # Clear session and sess.close must not get called for the first buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue(k['SD']['segment_model'])
-
-        # Non-final call
-
-        self.subject.K.set_session.reset_mock()
-        self.subject.K.clear_session.reset_mock()
-        k = {'SD' : {}}
-        state = [0,0,0]
-
-        new_state = self.subject.internal_keras_eval_transition(
-            state, self.dependent_var , self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(),
-            self.serialized_weights, self.compile_params, 0,
-            self.dist_key_mapping, 0, 4,
-            self.total_images_per_seg, False, self.accessible_gpus_for_seg, False, **k)
-        agg_loss, agg_accuracy, image_count = new_state
-
-        self.assertEqual(ending_image_count, image_count)
-        # set_session must not get called for the first buffer
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # loss and accuracy should be unchanged
-        self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
-        self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
-        # Clear session and sess.close must not get called for the first buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue(k['SD']['segment_model'])
-
-    def _test_internal_keras_eval_transition_middle_buffer(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
-        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
-
-        starting_image_count = len(self.dependent_var_int)
-        ending_image_count = starting_image_count + len(self.dependent_var_int)
-
-        # last iteration call
-
-        k = {'SD' : {}}
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-
-        state = [self.loss * starting_image_count, self.accuracy * starting_image_count, starting_image_count]
-        k['SD']['segment_model'] = self.model
-        k['SD']['sess'] = Mock()
-
-        new_state = self.subject.internal_keras_eval_transition(
-            state, self.dependent_var , self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(),
-            'dummy_model_weights', None, 0,
-            self.dist_key_mapping, 0, 4,
-            self.total_images_per_seg, False, self.accessible_gpus_for_seg, True, **k)
-
-        agg_loss, agg_accuracy, image_count = new_state
-
-        self.assertEqual(ending_image_count, image_count)
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-         # loss and accuracy should be unchanged
-        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
-        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-
-        # Non-last iteration call
-
-        self.subject.K.set_session.reset_mock()
-        self.subject.K.clear_session.reset_mock()
-        k = {'SD' : {}}
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-
-        state = [self.loss * starting_image_count, self.accuracy * starting_image_count, starting_image_count]
-        k['SD']['segment_model'] = self.model
-        k['SD']['sess'] = Mock()
-
-        new_state = self.subject.internal_keras_eval_transition(
-            state, self.dependent_var , self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(),
-            'dummy_model_weights', None, 0,
-            self.dist_key_mapping, 0, 4,
-            self.total_images_per_seg, False, self.accessible_gpus_for_seg, False, **k)
-        agg_loss, agg_accuracy, image_count = new_state
-
-        self.assertEqual(ending_image_count, image_count)
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-         # loss and accuracy should be unchanged
-        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
-        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-
-    def _test_internal_keras_eval_transition_last_buffer(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
-        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
-
-        starting_image_count = 2*len(self.dependent_var_int)
-        ending_image_count = starting_image_count + len(self.dependent_var_int)
-
-        k = {'SD' : {}}
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-
-        state = [self.loss * starting_image_count,
-                 self.accuracy * starting_image_count,
-                 starting_image_count]
-
-        k['SD']['segment_model'] = self.model
-        k['SD']['sess'] = Mock()
-
-        new_state = self.subject.internal_keras_eval_transition(
-            state, self.dependent_var , self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(),
-            'dummy_model_weights', None, 0,
-            self.dist_key_mapping, 0, 4,
-            self.total_images_per_seg, False, self.accessible_gpus_for_seg, True, **k)
-        agg_loss, agg_accuracy, image_count = new_state
-
-        self.assertEqual(ending_image_count, image_count)
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # loss and accuracy should be unchanged
-        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
-        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
-        # Clear session and sess.close must get called for the last buffer in gpdb,
-        #  but not in postgres
-        self.assertEqual(1, self.subject.K.clear_session.call_count)
-
-        # Non-final call
-
-        self.subject.K.set_session.reset_mock()
-        self.subject.K.clear_session.reset_mock()
-        k = {'SD' : {}}
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-
-        state = [self.loss * starting_image_count,
-                 self.accuracy * starting_image_count,
-                 starting_image_count]
-
-        k['SD']['segment_model'] = self.model
-        k['SD']['sess'] = Mock()
-
-        new_state = self.subject.internal_keras_eval_transition(
-            state, self.dependent_var , self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
-            self.model.to_json(),
-            'dummy_model_weights', None, 0,
-            self.dist_key_mapping, 0, 4,
-            self.total_images_per_seg, False, self.accessible_gpus_for_seg, False, **k)
-
-        agg_loss, agg_accuracy, image_count = new_state
-
-        self.assertEqual(ending_image_count, image_count)
-        # set_session is always called
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # loss and accuracy should be unchanged
-        self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
-        self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
-        # Clear session and sess.close must not get called in non-final iterations
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-
-    def test_internal_keras_eval_transition_first_buffer_pg(self):
-        self._test_internal_keras_eval_transition_first_buffer(True)
-
-    def test_internal_keras_eval_transition_first_buffer_gpdb(self):
-        self._test_internal_keras_eval_transition_first_buffer(False)
-
-    def test_internal_keras_eval_transition_middle_buffer_pg(self):
-        self._test_internal_keras_eval_transition_middle_buffer(True)
-
-    def test_internal_keras_eval_transition_middle_buffer_gpdb(self):
-        self._test_internal_keras_eval_transition_middle_buffer(False)
-
-    def test_internal_keras_eval_transition_last_buffer_pg(self):
-        self._test_internal_keras_eval_transition_last_buffer(True)
-
-    def test_internal_keras_eval_transition_last_buffer_gpdb(self):
-        self._test_internal_keras_eval_transition_last_buffer(False)
-
     def test_internal_keras_eval_merge(self):
         image_count = self.total_images_per_seg[0]
         state1 = [3.0*self.loss, 3.0*self.accuracy, image_count]
@@ -1919,7 +1792,16 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         self.module_patcher.stop()
 
 if __name__ == '__main__':
-    from tensorflow import keras
+    # Do not move any of the tensorflow imports outside of this block:
+    # importing tensorflow.keras.models/layers changes the `__name__`
+    # variable, which makes the if condition above fail so that the unit
+    # tests never run.
+
+    # turn off verbose output
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+    import tensorflow as tf
+    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
     from tensorflow.keras.models import *
     from tensorflow.keras.layers import *
     unittest.main()