This is an automated email from the ASF dual-hosted git repository. nkak pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit ebec58271efc87d697c35639f9bbcb7aa47cd399 Author: Nikhil Kak <[email protected]> AuthorDate: Wed Apr 10 16:20:49 2019 -0700 DL: Rename get_device_name function JIRA: MADLIB-1304 Renamed it because it was also doing a set operation Closes #367 Co-authored-by: Jingyi Mei <[email protected]> --- src/ports/postgres/modules/deep_learning/madlib_keras.py_in | 5 ++--- src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in | 2 +- .../modules/deep_learning/test/unit_tests/test_madlib_keras.py_in | 6 +++--- .../test/unit_tests/test_madlib_keras_serializer.py_in | 4 ++-- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in index b883cac..83ca5a4 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in @@ -336,8 +336,7 @@ def fit_transition(state, ind_var, dep_var, current_seg_id, num_classes, SD = kwargs['SD'] # Configure GPUs/CPUs - device_name = get_device_name_for_keras( - use_gpu, current_seg_id) + device_name = get_device_name_and_set_cuda_env(use_gpu, current_seg_id) # Set up system if this is the first buffer on segment' @@ -525,7 +524,7 @@ def evaluate1(schema_madlib, model_table, test_table, id_col, model_arch_table, def internal_keras_evaluate(dependent_var, independent_var, model_architecture, model_data, compile_params, use_gpu, seg, **kwargs): - device_name = get_device_name_for_keras(use_gpu, seg) + device_name = get_device_name_and_set_cuda_env(use_gpu, seg) model = model_from_json(model_architecture) model_shapes = madlib_keras_serializer.get_model_shapes(model) _, _, _, model_weights = madlib_keras_serializer.deserialize_weights( diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in index a9ebcef..6ebf96e 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in @@ -35,7 +35,7 @@ from utilities.utilities import _assert ####################################################################### ########### Keras specific functions ##### ####################################################################### -def get_device_name_for_keras(use_gpu, seg): +def get_device_name_and_set_cuda_env(use_gpu, seg): gpus_per_host = 4 if use_gpu: device_name = '/gpu:0' diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in index 059bd11..8ca4958 100644 --- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in +++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in @@ -151,12 +151,12 @@ class MadlibKerasFitTestCase(unittest.TestCase): self.assertEqual(0, self.subject.clear_keras_session.call_count) self.assertEqual(2, k['SD']['buffer_count']) - def test_get_device_name_for_keras(self): + def test_get_device_name_and_set_cuda_env(self): import os - self.assertEqual('/gpu:0', self.subject.get_device_name_for_keras( + self.assertEqual('/gpu:0', self.subject.get_device_name_and_set_cuda_env( True, 1)) self.assertEqual('1', os.environ["CUDA_VISIBLE_DEVICES"]) - self.assertEqual('/cpu:0', self.subject.get_device_name_for_keras( + self.assertEqual('/cpu:0', self.subject.get_device_name_and_set_cuda_env( False, 1)) self.assertEqual('-1', os.environ["CUDA_VISIBLE_DEVICES"]) diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_serializer.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_serializer.py_in index 6844327..8264800 100644 --- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_serializer.py_in +++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_serializer.py_in @@ -33,7 +33,7 @@ import plpy_mock as plpy m4_changequote(`<!', `!>') -class MadlibKerasHelperTestCase(unittest.TestCase): +class MadlibSerializerTestCase(unittest.TestCase): def setUp(self): self.plpy_mock = Mock(spec='error') patches = { @@ -139,7 +139,7 @@ class MadlibKerasHelperTestCase(unittest.TestCase): self.assertEqual(np.array([0,1,2,1,3,4,5], dtype=np.float32).tostring(), res) -class MadlibSerializerTestCase(unittest.TestCase): +class MadlibKerasHelperTestCase(unittest.TestCase): def setUp(self): self.plpy_mock = Mock(spec='error') patches = {
