You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nk...@apache.org on 2019/04/15 22:17:11 UTC
[madlib] 03/05: DL: Rename get_device_name function
This is an automated email from the ASF dual-hosted git repository.
nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
commit ebec58271efc87d697c35639f9bbcb7aa47cd399
Author: Nikhil Kak <nk...@pivotal.io>
AuthorDate: Wed Apr 10 16:20:49 2019 -0700
DL: Rename get_device_name function
JIRA: MADLIB-1304
Renamed it because, in addition to returning the device name, the function also sets the CUDA_VISIBLE_DEVICES environment variable.
Closes #367
Co-authored-by: Jingyi Mei <jm...@pivotal.io>
---
src/ports/postgres/modules/deep_learning/madlib_keras.py_in | 5 ++---
src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in | 2 +-
.../modules/deep_learning/test/unit_tests/test_madlib_keras.py_in | 6 +++---
.../test/unit_tests/test_madlib_keras_serializer.py_in | 4 ++--
4 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index b883cac..83ca5a4 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -336,8 +336,7 @@ def fit_transition(state, ind_var, dep_var, current_seg_id, num_classes,
SD = kwargs['SD']
# Configure GPUs/CPUs
- device_name = get_device_name_for_keras(
- use_gpu, current_seg_id)
+ device_name = get_device_name_and_set_cuda_env(use_gpu, current_seg_id)
# Set up system if this is the first buffer on segment'
@@ -525,7 +524,7 @@ def evaluate1(schema_madlib, model_table, test_table, id_col, model_arch_table,
def internal_keras_evaluate(dependent_var, independent_var, model_architecture,
model_data, compile_params, use_gpu, seg, **kwargs):
- device_name = get_device_name_for_keras(use_gpu, seg)
+ device_name = get_device_name_and_set_cuda_env(use_gpu, seg)
model = model_from_json(model_architecture)
model_shapes = madlib_keras_serializer.get_model_shapes(model)
_, _, _, model_weights = madlib_keras_serializer.deserialize_weights(
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index a9ebcef..6ebf96e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -35,7 +35,7 @@ from utilities.utilities import _assert
#######################################################################
########### Keras specific functions #####
#######################################################################
-def get_device_name_for_keras(use_gpu, seg):
+def get_device_name_and_set_cuda_env(use_gpu, seg):
gpus_per_host = 4
if use_gpu:
device_name = '/gpu:0'
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 059bd11..8ca4958 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -151,12 +151,12 @@ class MadlibKerasFitTestCase(unittest.TestCase):
self.assertEqual(0, self.subject.clear_keras_session.call_count)
self.assertEqual(2, k['SD']['buffer_count'])
- def test_get_device_name_for_keras(self):
+ def test_get_device_name_and_set_cuda_env(self):
import os
- self.assertEqual('/gpu:0', self.subject.get_device_name_for_keras(
+ self.assertEqual('/gpu:0', self.subject.get_device_name_and_set_cuda_env(
True, 1))
self.assertEqual('1', os.environ["CUDA_VISIBLE_DEVICES"])
- self.assertEqual('/cpu:0', self.subject.get_device_name_for_keras(
+ self.assertEqual('/cpu:0', self.subject.get_device_name_and_set_cuda_env(
False, 1))
self.assertEqual('-1', os.environ["CUDA_VISIBLE_DEVICES"])
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_serializer.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_serializer.py_in
index 6844327..8264800 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_serializer.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_serializer.py_in
@@ -33,7 +33,7 @@ import plpy_mock as plpy
m4_changequote(`<!', `!>')
-class MadlibKerasHelperTestCase(unittest.TestCase):
+class MadlibSerializerTestCase(unittest.TestCase):
def setUp(self):
self.plpy_mock = Mock(spec='error')
patches = {
@@ -139,7 +139,7 @@ class MadlibKerasHelperTestCase(unittest.TestCase):
self.assertEqual(np.array([0,1,2,1,3,4,5], dtype=np.float32).tostring(),
res)
-class MadlibSerializerTestCase(unittest.TestCase):
+class MadlibKerasHelperTestCase(unittest.TestCase):
def setUp(self):
self.plpy_mock = Mock(spec='error')
patches = {