You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nk...@apache.org on 2020/12/09 00:39:12 UTC
[madlib] 03/07: DL: Add unit test to assert for session and graph objects
This is an automated email from the ASF dual-hosted git repository.
nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 963568aae5b80803dab81dc10f6ffdc7ece30e89
Author: Ekta Khanna <ek...@vmware.com>
AuthorDate: Wed Dec 2 15:10:56 2020 -0800
DL: Add unit test to assert for session and graph objects
JIRA: MADLIB-1438
Prior to this commit, we mocked the calls to K.set_session and
K.clear_session. With this commit, we let Keras/TF create and set
the sessions so that the tests validate the actual session/graph behavior.
Co-authored-by: Ekta Khanna <ek...@vmware.com>
---
.../test/unit_tests/test_madlib_keras.py_in | 828 +++++++++------------
1 file changed, 355 insertions(+), 473 deletions(-)
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index a97e07e..1b0ee8d 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -45,7 +45,7 @@ except:
def mult(k,arr):
return [ k*a for a in arr ]
-class MadlibKerasFitTestCase(unittest.TestCase):
+class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase):
def setUp(self):
self.plpy_mock = Mock(spec='error')
patches = {
@@ -71,6 +71,8 @@ class MadlibKerasFitTestCase(unittest.TestCase):
self.model_weights = [3,4,5,6]
self.serialized_weights = np.array(self.model_weights, dtype=np.float32
).tostring()
+ self.loss = 0.5947071313858032
+ self.accuracy = 1.0
self.dist_key_mapping = [0,1,2]
self.accessible_gpus_for_seg = [0]
@@ -89,100 +91,52 @@ class MadlibKerasFitTestCase(unittest.TestCase):
self.dummy_prev_weights = 'dummy weights'
+ # Mock calls to tf.keras fit
+ model_class = self.subject.tf.keras.Model
+ model_class.fit = MagicMock()
+
def tearDown(self):
self.module_patcher.stop()
+ self.subject.K.clear_session()
- def _test_fit_transition_first_buffer_pass(self, is_platform_pg):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
+ def _test_fit_transition_first_buffer_pass(self, is_platform_pg, **kwargs):
self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
- starting_image_count = 0
ending_image_count = len(self.dependent_var_int)
- # last iteration Call
previous_state = np.array(self.model_weights, dtype=np.float32)
- k = {'SD': {}}
-
new_state = self.subject.fit_transition(
None, self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
self.model.to_json(), self.compile_params, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
- self.accessible_gpus_for_seg, previous_state.tostring(), True, **k)
-
+ self.accessible_gpus_for_seg, previous_state.tostring(), "todo-remove", **kwargs)
image_count = new_state
self.assertEqual(ending_image_count, image_count)
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session and sess.close must not get called for the first buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
- self.assertTrue(k['SD']['segment_model'])
-
- # Non-last iteration Call
- self.subject.K.set_session.reset_mock()
- self.subject.K.clear_session.reset_mock()
- previous_state = np.array(self.model_weights, dtype=np.float32)
-
- k = {'SD' : {}}
-
- new_state = self.subject.fit_transition(
- None, self.dependent_var, self.independent_var,
- self.dependent_var_shape, self.independent_var_shape,
- self.model.to_json(), self.compile_params, self.fit_params, 0,
- self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
- self.accessible_gpus_for_seg, previous_state.tostring(), False, **k)
-
- image_count = new_state
- self.assertEqual(ending_image_count, image_count)
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session and sess.close must not get called for the first buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
- self.assertTrue(k['SD']['segment_model'])
-
- def test_fit_transition_multiple_model_no_cache_first_buffer_pass(self):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
- starting_image_count = 0
+ def _test_fit_transition_multiple_model_no_cache_first_buffer_pass(self,
+ **kwargs):
ending_image_count = len(self.dependent_var_int)
previous_weights = np.array(self.model_weights, dtype=np.float32)
- k = {'SD': {}}
-
new_state = self.subject.fit_transition(
None, self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
self.model.to_json(), self.compile_params, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
- self.accessible_gpus_for_seg, previous_weights.tostring(), True, True, **k)
+ self.accessible_gpus_for_seg, previous_weights.tostring(),
+ "todo-remove", True, **kwargs)
image_count = new_state
self.assertEqual(ending_image_count, image_count)
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session must not be called for the first buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
- self.assertTrue(k['SD']['segment_model'])
def test_fit_transition_multiple_model_cache_first_buffer_pass(self):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
- starting_image_count = 0
ending_image_count = len(self.dependent_var_int)
previous_weights = np.array(self.model_weights, dtype=np.float32)
- k = {'SD': {}}
-
+ k = {'GD': {}}
new_state = self.subject.fit_multiple_transition_caching(
None, self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
@@ -192,108 +146,47 @@ class MadlibKerasFitTestCase(unittest.TestCase):
image_count = new_state
self.assertEqual(ending_image_count, image_count)
- # set_session should only be called for the last row
- self.assertEqual(0, self.subject.K.set_session.call_count)
- # Clear session must not be called for the first buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
- self.assertTrue('segment_model' not in k['SD'])
- self.assertTrue('cache_set' not in k['SD'])
- self.assertTrue(k['SD']['x_train'])
- self.assertTrue(k['SD']['y_train'])
-
- def _test_fit_transition_middle_buffer_pass(self, is_platform_pg):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
+ self.assertTrue('sess' not in k['GD'])
+ self.assertTrue('segment_model' not in k['GD'])
+ self.assertTrue('cache_set' not in k['GD'])
+ self.assertTrue(k['GD']['x_train'])
+ self.assertTrue(k['GD']['y_train'])
+
+ def _test_fit_transition_middle_buffer_pass(self, is_platform_pg, **kwargs):
self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
starting_image_count = len(self.dependent_var_int)
ending_image_count = starting_image_count + len(self.dependent_var_int)
- # last iteration Call
-
-
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
- k = {'SD': {'segment_model': self.model, 'sess': Mock()}}
-
- state = starting_image_count
- new_state = self.subject.fit_transition(
- state, self.dependent_var, self.independent_var,
- self.dependent_var_shape, self.independent_var_shape,
- self.model.to_json(), None, self.fit_params, 0,
- self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
- self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **k)
-
- image_count = new_state
- self.assertEqual(ending_image_count, image_count)
-
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session and sess.close must not get called for the middle buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
-
- # Non-last iteration Call
- self.subject.K.set_session.reset_mock()
- self.subject.K.clear_session.reset_mock()
-
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
- k = {'SD': {'segment_model': self.model, 'sess': Mock()}}
-
state = starting_image_count
new_state = self.subject.fit_transition(
state, self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
self.model.to_json(), None, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
- self.accessible_gpus_for_seg, self.dummy_prev_weights, False, **k)
+ self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **kwargs)
image_count = new_state
self.assertEqual(ending_image_count, image_count)
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session and sess.close must not get called for the middle buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
-
- def test_fit_transition_multiple_model_no_cache_middle_buffer_pass(self):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
+ def _test_fit_transition_multiple_model_no_cache_middle_buffer_pass(self,
+ **kwargs):
starting_image_count = len(self.dependent_var_int)
ending_image_count = starting_image_count + len(self.dependent_var_int)
- # last iteration Call
-
-
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
- k = {'SD': {'segment_model': self.model, 'sess': Mock()}}
-
state = starting_image_count
new_state = self.subject.fit_transition(
state, self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
self.model.to_json(), None, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
- self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True, **k)
+ self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True,
+ **kwargs)
image_count = new_state
self.assertEqual(ending_image_count, image_count)
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session and sess.close must not get called for the middle buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
def test_fit_transition_multiple_model_cache_middle_buffer_pass(self):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
-
starting_image_count = len(self.dependent_var_int)
ending_image_count = starting_image_count + len(self.dependent_var_int)
@@ -303,7 +196,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
- k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+ k = {'GD': {'x_train': x_train, 'y_train': y_train}}
state = starting_image_count
new_state = self.subject.fit_multiple_transition_caching(
@@ -312,103 +205,125 @@ class MadlibKerasFitTestCase(unittest.TestCase):
self.model.to_json(), self.compile_params, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
image_count = new_state
self.assertEqual(ending_image_count, image_count)
- # set_session is only called for the last buffer
- self.assertEqual(0, self.subject.K.set_session.call_count)
- # Clear session and sess.close must not get called for the middle buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
- self.assertTrue('segment_model' not in k['SD'])
- self.assertTrue('cache_set' not in k['SD'])
- self.assertTrue(k['SD']['x_train'])
- self.assertTrue(k['SD']['y_train'])
-
- def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
+ self.assertTrue('sess' not in k['GD'])
+ self.assertTrue('segment_model' not in k['GD'])
+ self.assertTrue('cache_set' not in k['GD'])
+ self.assertTrue(k['GD']['x_train'])
+ self.assertTrue(k['GD']['y_train'])
+
+ def _test_fit_transition_last_buffer_pass(self, is_platform_pg, **kwargs):
self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
starting_image_count = 2*len(self.dependent_var_int)
ending_image_count = starting_image_count + len(self.dependent_var_int)
- # last iteration Call
-
- multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
-
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
- k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
-
state = starting_image_count
+ previous_state = np.array(self.model_weights, dtype=np.float32)
new_state = self.subject.fit_transition(
state, self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
self.model.to_json(), None, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
- self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **k)
-
+ self.accessible_gpus_for_seg, previous_state.tostring(), "todo-remove",
+ **kwargs)
state = np.fromstring(new_state, dtype=np.float32)
image_count = state[0]
- weights = np.rint(state[1:]).astype(np.int)
+ # We need to assert that the weights should be multiplied by final image count.
+ weights = state[1:]
+ multiplied_weights = mult(self.total_images_per_seg[0], self.model_weights)
+ self.assertTrue((weights == multiplied_weights).all())
+ self.assertEqual(ending_image_count, image_count)
+
+ def _test_internal_keras_eval_transition_first_buffer(self, is_platform_pg,
+ last_iteration = False,
+ **kwargs):
+ self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
+ ending_image_count = len(self.dependent_var_int)
+
+ state = [0,0,0]
+ new_state = self.subject.internal_keras_eval_transition(
+ state, self.dependent_var , self.independent_var,
+ self.dependent_var_shape, self.independent_var_shape,
+ self.model.to_json(),
+ self.serialized_weights, self.compile_params, 0,
+ self.dist_key_mapping, 0, 4,
+ self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+ last_iteration, None, **kwargs)
+
+ agg_loss, agg_accuracy, image_count = new_state
+
self.assertEqual(ending_image_count, image_count)
- # weights should be multiplied by final image count
- self.assertTrue((multiplied_weights == weights).all())
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
+ # loss and accuracy should be unchanged
+ self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
+ self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
- # Non-last iteration Call
- self.subject.K.set_session.reset_mock()
- self.subject.K.clear_session.reset_mock()
+ def _test_internal_keras_eval_transition_last_buffer(self, is_platform_pg,
+ last_iteration = False,
+ **kwargs):
+ self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
- multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
+ starting_image_count = 2*len(self.dependent_var_int)
+ ending_image_count = starting_image_count + len(self.dependent_var_int)
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
- k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+ state = [self.loss * starting_image_count,
+ self.accuracy * starting_image_count,
+ starting_image_count]
- state = starting_image_count
- new_state = self.subject.fit_transition(
- state, self.dependent_var, self.independent_var,
+ new_state = self.subject.internal_keras_eval_transition(
+ state, self.dependent_var , self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
- self.model.to_json(), None, self.fit_params, 0,
- self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
- self.accessible_gpus_for_seg, self.dummy_prev_weights, False, **k)
+ self.model.to_json(),
+ 'dummy_model_weights', None, 0,
+ self.dist_key_mapping, 0, 4,
+ self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+ last_iteration, **kwargs)
+ agg_loss, agg_accuracy, image_count = new_state
- state = np.fromstring(new_state, dtype=np.float32)
- image_count = state[0]
- weights = np.rint(state[1:]).astype(np.int)
self.assertEqual(ending_image_count, image_count)
- # weights should be multiplied by final image count
- self.assertTrue((multiplied_weights == weights).all())
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session and sess.close must get called for the last buffer in gpdb,
- # but not in postgres
- self.assertEqual(0, self.subject.K.clear_session.call_count)
-
- def test_fit_transition_multiple_model_no_cache_last_buffer_pass(self):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
+ # loss and accuracy should be unchanged
+ self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
+ self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
- starting_image_count = 2*len(self.dependent_var_int)
+ def _test_internal_keras_eval_transition_middle_buffer(self,
+ last_iteration = False,
+ **kwargs):
+ starting_image_count = len(self.dependent_var_int)
ending_image_count = starting_image_count + len(self.dependent_var_int)
- # last iteration Call
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
- k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
+
+ state = [self.loss * starting_image_count,
+ self.accuracy * starting_image_count, starting_image_count]
+
+ new_state = self.subject.internal_keras_eval_transition(
+ state, self.dependent_var , self.independent_var,
+ self.dependent_var_shape, self.independent_var_shape,
+ self.model.to_json(),
+ 'dummy_model_weights', None, 0,
+ self.dist_key_mapping, 0, 4,
+ self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+ last_iteration, **kwargs)
+
+ agg_loss, agg_accuracy, image_count = new_state
+
+ self.assertEqual(ending_image_count, image_count)
+ self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
+ self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
+
+ def _test_fit_transition_multiple_model_no_cache_last_buffer_pass(self,
+ **kwargs):
+ starting_image_count = 2*len(self.dependent_var_int)
state = starting_image_count
new_state = self.subject.fit_transition(
- state , self.dependent_var, self.independent_var,
+ state , self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
self.model.to_json(), None, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
- self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True, **k)
+ self.accessible_gpus_for_seg, self.dummy_prev_weights, "todo-remove",
+ True, **kwargs)
state = np.fromstring(new_state, dtype=np.float32)
weights = np.rint(state[0:]).astype(np.int)
@@ -417,15 +332,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
# fit multiple
self.assertEqual(len(self.model_weights), len(weights))
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
-
def test_fit_transition_multiple_model_cache_last_buffer_pass(self):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
-
starting_image_count = 2*len(self.dependent_var_int)
ending_image_count = starting_image_count + len(self.dependent_var_int)
@@ -437,16 +344,18 @@ class MadlibKerasFitTestCase(unittest.TestCase):
x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
- k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+ k = {'GD': {'x_train': x_train, 'y_train': y_train}}
state = starting_image_count
+ graph1 = self.subject.tf.get_default_graph()
new_state = self.subject.fit_multiple_transition_caching(
state, self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
self.model.to_json(), self.compile_params, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
-
+ graph2 = self.subject.tf.get_default_graph()
+ self.assertNotEquals(graph1, graph2)
state = np.fromstring(new_state, dtype=np.float32)
weights = np.rint(state[0:]).astype(np.int)
@@ -454,21 +363,13 @@ class MadlibKerasFitTestCase(unittest.TestCase):
# fit multiple
self.assertEqual(len(self.model_weights), len(weights))
- # set_session is only called for the last buffer
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session and sess.close must get called for the last buffer
- self.assertEqual(1, self.subject.K.clear_session.call_count)
- self.assertTrue('segment_model' not in k['SD'])
- self.assertTrue(k['SD']['cache_set'])
- self.assertTrue(k['SD']['x_train'])
- self.assertTrue(k['SD']['y_train'])
+ self.assertTrue('sess' not in k['GD'])
+ self.assertTrue('segment_model' not in k['GD'])
+ self.assertTrue(k['GD']['cache_set'])
+ self.assertTrue(k['GD']['x_train'])
+ self.assertTrue(k['GD']['y_train'])
def test_fit_transition_multiple_model_cache_filled_pass(self):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
-
starting_image_count = 2*len(self.dependent_var_int)
ending_image_count = starting_image_count + len(self.dependent_var_int)
@@ -482,15 +383,20 @@ class MadlibKerasFitTestCase(unittest.TestCase):
x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
- k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
-
+ self.subject.compile_and_set_weights(self.model, self.compile_params,
+ '/cpu:0', self.serialized_weights)
+ s1 = self.subject.K.get_session()
+ k = {'GD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True,
+ 'sess': s1, 'segment_model': self.model}}
+ graph1 = self.subject.tf.get_default_graph()
new_state = self.subject.fit_multiple_transition_caching(
None, self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
self.model.to_json(), self.compile_params, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
-
+ graph2 = self.subject.tf.get_default_graph()
+ self.assertNotEquals(graph1, graph2)
state = np.fromstring(new_state, dtype=np.float32)
weights = np.rint(state[0:]).astype(np.int)
@@ -498,21 +404,13 @@ class MadlibKerasFitTestCase(unittest.TestCase):
# fit multiple
self.assertEqual(len(self.model_weights), len(weights))
- # set_session is only called for the last buffer
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session and sess.close must get called for the last buffer
- self.assertEqual(1, self.subject.K.clear_session.call_count)
- self.assertTrue('segment_model' not in k['SD'])
- self.assertTrue(k['SD']['cache_set'])
- self.assertTrue(k['SD']['x_train'])
- self.assertTrue(k['SD']['y_train'])
+ self.assertTrue('sess' not in k['GD'])
+ self.assertTrue('segment_model' not in k['GD'])
+ self.assertTrue(k['GD']['cache_set'])
+ self.assertTrue(k['GD']['x_train'])
+ self.assertTrue(k['GD']['y_train'])
def test_fit_transition_multiple_model_cache_filled_final_training_pass(self):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
-
starting_image_count = 2*len(self.dependent_var_int)
ending_image_count = starting_image_count + len(self.dependent_var_int)
@@ -526,14 +424,16 @@ class MadlibKerasFitTestCase(unittest.TestCase):
x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
- k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
-
+ k = {'GD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+ graph1 = self.subject.tf.get_default_graph()
new_state = self.subject.fit_multiple_transition_caching(
None, self.dependent_var, self.independent_var,
self.dependent_var_shape, self.independent_var_shape,
self.model.to_json(), self.compile_params, self.fit_params, 0,
self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+ graph2 = self.subject.tf.get_default_graph()
+ self.assertNotEquals(graph1, graph2)
state = np.fromstring(new_state, dtype=np.float32)
weights = np.rint(state[0:]).astype(np.int)
@@ -542,14 +442,11 @@ class MadlibKerasFitTestCase(unittest.TestCase):
# fit multiple
self.assertEqual(len(self.model_weights), len(weights))
- # set_session is only called for the last buffer
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # Clear session and sess.close must get called for the last buffer
- self.assertEqual(1, self.subject.K.clear_session.call_count)
- self.assertTrue('segment_model' not in k['SD'])
- self.assertTrue('cache_set' not in k['SD'])
- self.assertTrue('x_train' not in k['SD'])
- self.assertTrue('y_train' not in k['SD'])
+ self.assertTrue('sess' not in k['GD'])
+ self.assertTrue('segment_model' not in k['GD'])
+ self.assertTrue('cache_set' not in k['GD'])
+ self.assertTrue('x_train' not in k['GD'])
+ self.assertTrue('y_train' not in k['GD'])
def test_fit_transition_first_buffer_pass_pg(self):
self._test_fit_transition_first_buffer_pass(True)
@@ -569,6 +466,204 @@ class MadlibKerasFitTestCase(unittest.TestCase):
def test_fit_transition_last_buffer_pass_gpdb(self):
self._test_fit_transition_last_buffer_pass(False)
+ ############### GRAPH AND SESSION TESTS ################################
+ def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+ kwargs = {'GD': {}}
+ GD = kwargs['GD']
+
+ ######################### fit for 2 iterations ##########
+ # iteration 1
+ first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+ self._assert_keras_session_same_as_gd_session(GD)
+
+ first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+ # iteration 2 (last iteration)
+ last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+ self._assert_keras_session_same_as_gd_session(GD)
+
+ last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+ self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+ self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+ ###################### eval transition for last iteration ###########
+ self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+ eval_last_iter_keras_sess = self.subject.K.get_session()
+ eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+ self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+ self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+ self._assert_gd_cleared(GD)
+
+ def test_fit_eval_2_iterations_mcf_1_gpdb(self):
+ kwargs = {'GD': {}}
+ GD = kwargs['GD']
+
+ ######################### fit + eval for 2 iterations ##########
+ # iteration 1 fit
+ first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+ self._assert_keras_session_same_as_gd_session(GD)
+
+ first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+ # iteration 1 eval
+ self._run_eval_iteration(False, first_iter_keras_sess, first_iter_tf_graph, **kwargs)
+ self._assert_keras_session_same_as_gd_session(GD)
+
+ eval_first_iter_keras_sess = self.subject.K.get_session()
+ eval_first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+ self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess)
+ self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph)
+
+ # iteration 2 fit (last iteration)
+ last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+ self._assert_keras_session_same_as_gd_session(GD)
+
+ last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+ self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+ self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+ # iteration 2 eval (last iteration)
+ # eval and fit use the same session & graph for all iterations.
+ # After the last call to eval(last iteration), we want to assert
+ # the session was closed and the graph cleared out.
+ # To assert this we call get_session() and get_default_graph()
+ # that will give a new session and a new graph and assert its not
+ # equal to the prev iteration session and graph.
+ self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+
+ eval_last_iter_keras_sess = self.subject.K.get_session()
+ eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+ self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+ self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+ self._assert_gd_cleared(GD)
+
+ def test_fit_multiple_2_iterations(self):
+ kwargs = {'GD': {}}
+ GD = kwargs['GD']
+
+ ############ fit multiple for 2 iterations ##########
+ # iteration 1
+ # first_iter_tf_graph is used to assert that calling fit_multiple closes the tf session
+ # and graph at the last buffer.
+ # It is fetched prior to calling the fit_transition(from fit_multiple) as when we create
+ # a session inside fit_transition, instead of creating a new graph it will use first_iter_tf_graph.
+ # This enables us to do the not equals assert.
+ first_iter_tf_graph = self.subject.tf.get_default_graph()
+ first_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+ self._assert_gd_cleared(GD)
+
+ # iteration 2 (last iteration)
+ last_iter_tf_graph = self.subject.tf.get_default_graph()
+ last_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+ self._assert_gd_cleared(GD)
+
+ self.assertNotEquals(first_iter_keras_sess, last_iter_keras_sess)
+ self.assertNotEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+ def test_eval_multiple_any_iteration(self):
+ # This test tests 2 things:
+ # 1. Calling eval_transition from fit_multiple
+ # 2. Calling eval_transition from evaluate directly
+ kwargs = {'GD': {}}
+ GD = kwargs['GD']
+
+ # eval_iter_tf_graph1 is used to assert that calling eval clears the tf session and graph
+ # It is fetched prior to calling the eval_transition as when we create a session inside
+ # eval_transition, instead of creating a new graph it will use eval_iter_tf_graph1.
+ # This enables us to do the not equals assert.
+ eval_iter_tf_graph1 = self.subject.tf.get_default_graph()
+ eval_iter_keras_sess1 = self._run_eval_iteration(True, None, None, True, **kwargs)
+ eval_iter_keras_sess2 = self.subject.K.get_session()
+ eval_iter_tf_graph2 = self.subject.tf.get_default_graph()
+
+ self.assertNotEquals(eval_iter_keras_sess1, eval_iter_keras_sess2)
+ self.assertNotEquals(eval_iter_tf_graph1, eval_iter_tf_graph2)
+ self._assert_gd_cleared(GD)
+
+ def _run_eval_iteration(self, final_iteration, prev_keras_sess, prev_tf_graph, called_from_fit_multiple=False, **kwargs):
+ self._test_internal_keras_eval_transition_first_buffer(final_iteration,
+ **kwargs)
+ self._assert_gd_is_valid(kwargs['GD'])
+ self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+
+ eval_first_buffer_keras_sess = kwargs['GD']['sess']
+ self.assertFalse(eval_first_buffer_keras_sess._closed)
+ eval_first_buffer_tf_graph = self.subject.tf.get_default_graph()
+
+ if not called_from_fit_multiple:
+ self.assertEquals(eval_first_buffer_keras_sess, prev_keras_sess)
+ self.assertEquals(eval_first_buffer_tf_graph, prev_tf_graph)
+
+ self._test_internal_keras_eval_transition_middle_buffer(final_iteration,
+ **kwargs )
+ self._assert_gd_is_valid(kwargs['GD'])
+ self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+ self.assertFalse(eval_first_buffer_keras_sess._closed)
+
+ self._test_internal_keras_eval_transition_last_buffer(final_iteration,
+ **kwargs)
+ if final_iteration:
+ self._assert_gd_cleared(kwargs['GD'])
+ self.assertTrue(eval_first_buffer_keras_sess._closed)
+ else:
+ self._assert_gd_is_valid(kwargs['GD'])
+ self.assertFalse(eval_first_buffer_keras_sess._closed)
+ return eval_first_buffer_keras_sess
+
+ def _run_fit_iteration(self, **kwargs):
+ self._test_fit_transition_first_buffer_pass(**kwargs)
+ gd_first_buffer = kwargs['GD']
+ self._assert_gd_is_valid(gd_first_buffer)
+ iter_sess = gd_first_buffer['sess']
+ self.assertFalse(iter_sess._closed)
+ self._assert_keras_session_same_as_gd_session(gd_first_buffer)
+
+ self._test_fit_transition_middle_buffer_pass(**kwargs)
+ gd_middle_buffer = kwargs['GD']
+ self._assert_gd_is_valid(gd_middle_buffer)
+ self.assertFalse(iter_sess._closed)
+
+ self._test_fit_transition_last_buffer_pass(**kwargs)
+ gd_last_buffer = kwargs['GD']
+ self._assert_gd_is_valid(gd_last_buffer)
+ self.assertFalse(iter_sess._closed)
+ return iter_sess
+
+ def _run_fit_multiple_iteration(self, **kwargs):
+ self._test_fit_transition_multiple_model_no_cache_first_buffer_pass(**kwargs)
+ self._assert_gd_is_valid(kwargs['GD'])
+ self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+ iter_sess = kwargs['GD']['sess']
+ self.assertFalse(iter_sess._closed)
+
+ self._test_fit_transition_multiple_model_no_cache_middle_buffer_pass(**kwargs)
+ self._assert_gd_is_valid(kwargs['GD'])
+ self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+ self.assertFalse(iter_sess._closed)
+
+ self._test_fit_transition_multiple_model_no_cache_last_buffer_pass(**kwargs)
+ self._assert_gd_cleared(kwargs['GD'])
+ self.assertTrue(iter_sess._closed)
+ return iter_sess
+
+ def _assert_keras_session_same_as_gd_session(self, gd):
+ sess = self.subject.K.get_session()
+ self.assertEquals(sess, gd['sess'])
+
+ def _assert_gd_cleared(self, gd):
+ self.assertEquals(0, len(gd.keys()))
+
+ def _assert_gd_is_valid(self, gd):
+ self.assertTrue(gd['sess'])
+ self.assertTrue(gd['segment_model'])
+
+ ################################################################
+
def test_fit_transition_first_tuple_none_ind_var_dep_var(self):
k = {}
self.assertEqual('dummy_state',
@@ -1580,7 +1675,7 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
"This is not valid PostgresSQL: SELECT {}[1]".format(metrics)
)
-class MadlibKerasEvaluationTestCase(unittest.TestCase):
+class MadlibKerasEvaluationMergeFinalTestCase(unittest.TestCase):
def setUp(self):
self.plpy_mock = Mock(spec='error')
patches = {
@@ -1628,228 +1723,6 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
def tearDown(self):
self.module_patcher.stop()
- def _test_internal_keras_eval_transition_first_buffer(self, is_platform_pg):
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
- self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
- starting_image_count = 0
- ending_image_count = len(self.dependent_var_int)
-
- # last iteration call
-
- k = {'SD' : {}}
- state = [0,0,0]
- new_state = self.subject.internal_keras_eval_transition(
- state, self.dependent_var , self.independent_var,
- self.dependent_var_shape, self.independent_var_shape,
- self.model.to_json(),
- self.serialized_weights, self.compile_params, 0,
- self.dist_key_mapping, 0, 4,
- self.total_images_per_seg, False, self.accessible_gpus_for_seg, True, **k)
-
- agg_loss, agg_accuracy, image_count = new_state
-
- self.assertEqual(ending_image_count, image_count)
- # Call set_session once for gpdb (but not for postgres)
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # loss and accuracy should be unchanged
- self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
- self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
- # Clear session and sess.close must not get called for the first buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
- self.assertTrue(k['SD']['segment_model'])
-
- # Non-final call
-
- self.subject.K.set_session.reset_mock()
- self.subject.K.clear_session.reset_mock()
- k = {'SD' : {}}
- state = [0,0,0]
-
- new_state = self.subject.internal_keras_eval_transition(
- state, self.dependent_var , self.independent_var,
- self.dependent_var_shape, self.independent_var_shape,
- self.model.to_json(),
- self.serialized_weights, self.compile_params, 0,
- self.dist_key_mapping, 0, 4,
- self.total_images_per_seg, False, self.accessible_gpus_for_seg, False, **k)
- agg_loss, agg_accuracy, image_count = new_state
-
- self.assertEqual(ending_image_count, image_count)
- # set_session must not get called for the first buffer
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # loss and accuracy should be unchanged
- self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
- self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
- # Clear session and sess.close must not get called for the first buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
- self.assertTrue(k['SD']['segment_model'])
-
- def _test_internal_keras_eval_transition_middle_buffer(self, is_platform_pg):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
- self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
-
- starting_image_count = len(self.dependent_var_int)
- ending_image_count = starting_image_count + len(self.dependent_var_int)
-
- # last iteration call
-
- k = {'SD' : {}}
-
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
-
- state = [self.loss * starting_image_count, self.accuracy * starting_image_count, starting_image_count]
- k['SD']['segment_model'] = self.model
- k['SD']['sess'] = Mock()
-
- new_state = self.subject.internal_keras_eval_transition(
- state, self.dependent_var , self.independent_var,
- self.dependent_var_shape, self.independent_var_shape,
- self.model.to_json(),
- 'dummy_model_weights', None, 0,
- self.dist_key_mapping, 0, 4,
- self.total_images_per_seg, False, self.accessible_gpus_for_seg, True, **k)
-
- agg_loss, agg_accuracy, image_count = new_state
-
- self.assertEqual(ending_image_count, image_count)
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # loss and accuracy should be unchanged
- self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
- self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
- # Clear session and sess.close must not get called for the middle buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
-
- # Non-last iteration call
-
- self.subject.K.set_session.reset_mock()
- self.subject.K.clear_session.reset_mock()
- k = {'SD' : {}}
-
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
-
- state = [self.loss * starting_image_count, self.accuracy * starting_image_count, starting_image_count]
- k['SD']['segment_model'] = self.model
- k['SD']['sess'] = Mock()
-
- new_state = self.subject.internal_keras_eval_transition(
- state, self.dependent_var , self.independent_var,
- self.dependent_var_shape, self.independent_var_shape,
- self.model.to_json(),
- 'dummy_model_weights', None, 0,
- self.dist_key_mapping, 0, 4,
- self.total_images_per_seg, False, self.accessible_gpus_for_seg, False, **k)
- agg_loss, agg_accuracy, image_count = new_state
-
- self.assertEqual(ending_image_count, image_count)
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # loss and accuracy should be unchanged
- self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
- self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
- # Clear session and sess.close must not get called for the middle buffer
- self.assertEqual(0, self.subject.K.clear_session.call_count)
-
- def _test_internal_keras_eval_transition_last_buffer(self, is_platform_pg):
- #TODO should we mock tensorflow's close_session and keras'
- # clear_session instead of mocking the function `K.clear_session`
- self.subject.K.set_session = Mock()
- self.subject.K.clear_session = Mock()
- self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
-
- starting_image_count = 2*len(self.dependent_var_int)
- ending_image_count = starting_image_count + len(self.dependent_var_int)
-
- k = {'SD' : {}}
-
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
-
- state = [self.loss * starting_image_count,
- self.accuracy * starting_image_count,
- starting_image_count]
-
- k['SD']['segment_model'] = self.model
- k['SD']['sess'] = Mock()
-
- new_state = self.subject.internal_keras_eval_transition(
- state, self.dependent_var , self.independent_var,
- self.dependent_var_shape, self.independent_var_shape,
- self.model.to_json(),
- 'dummy_model_weights', None, 0,
- self.dist_key_mapping, 0, 4,
- self.total_images_per_seg, False, self.accessible_gpus_for_seg, True, **k)
- agg_loss, agg_accuracy, image_count = new_state
-
- self.assertEqual(ending_image_count, image_count)
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # loss and accuracy should be unchanged
- self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
- self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
- # Clear session and sess.close must get called for the last buffer in gpdb,
- # but not in postgres
- self.assertEqual(1, self.subject.K.clear_session.call_count)
-
- # Non-final call
-
- self.subject.K.set_session.reset_mock()
- self.subject.K.clear_session.reset_mock()
- k = {'SD' : {}}
-
- self.subject.compile_and_set_weights(self.model, self.compile_params,
- '/cpu:0', self.serialized_weights)
-
- state = [self.loss * starting_image_count,
- self.accuracy * starting_image_count,
- starting_image_count]
-
- k['SD']['segment_model'] = self.model
- k['SD']['sess'] = Mock()
-
- new_state = self.subject.internal_keras_eval_transition(
- state, self.dependent_var , self.independent_var,
- self.dependent_var_shape, self.independent_var_shape,
- self.model.to_json(),
- 'dummy_model_weights', None, 0,
- self.dist_key_mapping, 0, 4,
- self.total_images_per_seg, False, self.accessible_gpus_for_seg, False, **k)
-
- agg_loss, agg_accuracy, image_count = new_state
-
- self.assertEqual(ending_image_count, image_count)
- # set_session is always called
- self.assertEqual(1, self.subject.K.set_session.call_count)
- # loss and accuracy should be unchanged
- self.assertAlmostEqual(self.loss * ending_image_count, agg_loss, 4)
- self.assertAlmostEqual(self.accuracy * ending_image_count, agg_accuracy, 4)
- # Clear session and sess.close must not get called in non-final iterations
- self.assertEqual(0, self.subject.K.clear_session.call_count)
-
- def test_internal_keras_eval_transition_first_buffer_pg(self):
- self._test_internal_keras_eval_transition_first_buffer(True)
-
- def test_internal_keras_eval_transition_first_buffer_gpdb(self):
- self._test_internal_keras_eval_transition_first_buffer(False)
-
- def test_internal_keras_eval_transition_middle_buffer_pg(self):
- self._test_internal_keras_eval_transition_middle_buffer(True)
-
- def test_internal_keras_eval_transition_middle_buffer_gpdb(self):
- self._test_internal_keras_eval_transition_middle_buffer(False)
-
- def test_internal_keras_eval_transition_last_buffer_pg(self):
- self._test_internal_keras_eval_transition_last_buffer(True)
-
- def test_internal_keras_eval_transition_last_buffer_gpdb(self):
- self._test_internal_keras_eval_transition_last_buffer(False)
-
def test_internal_keras_eval_merge(self):
image_count = self.total_images_per_seg[0]
state1 = [3.0*self.loss, 3.0*self.accuracy, image_count]
@@ -1919,7 +1792,16 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
self.module_patcher.stop()
if __name__ == '__main__':
- from tensorflow import keras
+ # Do not move any of the tensorflow imports outside of this block. This is
+ # because importing tensorflow.keras.models/layers changes the `__name__`
+ # variable because of which the if condition fails and the unit tests don't
+ # get run
+
+ # turn off verbose output
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+ import tensorflow as tf
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
unittest.main()